diff --git a/router/qrouter/proxy_routing.go b/router/qrouter/proxy_routing.go index 38bbe38d2..3f07918b7 100644 --- a/router/qrouter/proxy_routing.go +++ b/router/qrouter/proxy_routing.go @@ -43,7 +43,7 @@ type RoutingMetadataContext struct { // SELECT * FROM a join b WHERE a.c1 = and a.c2 = // can be routed with different rules rels map[RelationFQN]struct{} - exprs map[RelationFQN]map[string]string + exprs map[RelationFQN]map[string][]string // cached CTE names cteNames map[string]struct{} @@ -65,7 +65,7 @@ func NewRoutingMetadataContext(params [][]byte, paramsFormatCodes []int16) *Rout rels: map[RelationFQN]struct{}{}, cteNames: map[string]struct{}{}, tableAliases: map[string]RelationFQN{}, - exprs: map[RelationFQN]map[string]string{}, + exprs: map[RelationFQN]map[string][]string{}, unparsed_columns: map[string]struct{}{}, params: params, } @@ -104,13 +104,13 @@ func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN } meta.rels[resolvedRelation] = struct{}{} if _, ok := meta.exprs[resolvedRelation]; !ok { - meta.exprs[resolvedRelation] = map[string]string{} + meta.exprs[resolvedRelation] = map[string][]string{} } delete(meta.unparsed_columns, colname) - if curExpr, ok := meta.exprs[resolvedRelation][colname]; ok && curExpr != expr { - return spqrerror.Newf(spqrerror.SPQR_COMPLEX_QUERY, "several different values for distribution key") + if _, ok := meta.exprs[resolvedRelation][colname]; !ok { + meta.exprs[resolvedRelation][colname] = make([]string, 0) } - meta.exprs[resolvedRelation][colname] = expr + meta.exprs[resolvedRelation][colname] = append(meta.exprs[resolvedRelation][colname], expr) return nil } @@ -881,7 +881,7 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s ok := true - var hashedKey []byte + var hashedKeys [][]byte // TODO: multi-column routing. This works only for one-dim routing for i := 0; i < len(distrKey); i++ { @@ -894,20 +894,22 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s col := distrKey[i].Column - val, valOk := meta.exprs[rfqn][col] + vals, valOk := meta.exprs[rfqn][col] if !valOk { ok = false break } - hashedKey, err = hashfunction.ApplyHashFunction([]byte(val), hf) + hashedKeys = make([][]byte, len(vals)) + for i, val := range vals { + hashedKeys[i], err = hashfunction.ApplyHashFunction([]byte(val), hf) + spqrlog.Zero.Debug().Str("key", meta.exprs[rfqn][col][i]).Str("hashed key", string(hashedKeys[i])).Msg("applying hash function on key") - spqrlog.Zero.Debug().Str("key", meta.exprs[rfqn][col]).Str("hashed key", string(hashedKey)).Msg("applying hash function on key") - - if err != nil { - spqrlog.Zero.Debug().Err(err).Msg("failed to apply hash function") - ok = false - break + if err != nil { + spqrlog.Zero.Debug().Err(err).Msg("failed to apply hash function") + ok = false + break + } } } @@ -915,24 +917,24 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s // skip this relation continue } + for _, hashedKey := range hashedKeys { + currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) + if err != nil { + route_err = err + spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error") + continue + } - currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) - if err != nil { - route_err = err - spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error") - continue - } - - spqrlog.Zero.Debug(). - Interface("currroute", currroute). - Str("table", rfqn.RelationName). - Msg("calculated route for table/cols") - - route = routingstate.Combine(route, routingstate.ShardMatchState{ - Route: currroute, - TargetSessionAttrs: tsa, - }) + spqrlog.Zero.Debug(). + Interface("currroute", currroute). + Str("table", rfqn.RelationName). + Msg("calculated route for table/cols") + route = routingstate.Combine(route, routingstate.ShardMatchState{ + Route: currroute, + TargetSessionAttrs: tsa, + }) + } } if route == nil && route_err != nil { return nil, route_err diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index 4ab5ffe25..3505bf632 100644 --- a/router/qrouter/proxy_routing_test.go +++ b/router/qrouter/proxy_routing_test.go @@ -2,7 +2,6 @@ package qrouter_test import ( "context" - "github.com/pg-sharding/spqr/pkg/models/spqrerror" "testing" "github.com/pg-sharding/spqr/pkg/config" @@ -333,8 +332,36 @@ func TestCTE(t *testing.T) { ) SELECT * FROM xxxx; `, - err: spqrerror.Newf(spqrerror.SPQR_COMPLEX_QUERY, "several different values for distribution key."), - exp: nil, + err: nil, + exp: routingstate.SkipRoutingState{}, + }, + { + query: ` + WITH xxxx AS ( + SELECT * from t where i = 1 + ), + zzzz AS ( + UPDATE t + SET a = 0 + WHERE i = 2 + ) + SELECT * FROM xxxx; + `, + err: nil, + exp: routingstate.ShardMatchState{ + Route: &routingstate.DataShardRoute{ + Shkey: kr.ShardKey{ + Name: "sh1", + }, + Matchedkr: &kr.KeyRange{ + ShardID: "sh1", + ID: "id1", + Distribution: distribution, + LowerBound: []byte("1"), + }, + }, + TargetSessionAttrs: "any", + }, }, } { parserRes, err := lyx.Parse(tt.query)