Skip to content

Commit

Permalink
Process several set values for each distribution key (#625)
Browse files Browse the repository at this point in the history
* Add another test for CTE

* Forbid setting several values for distribution key

* Process several values for each distribution key

* Reformatting single_shard_joins regress test

* Fixes

* Removed redundant test in proxy_routing_test.go
  • Loading branch information
EinKrebs authored Apr 25, 2024
1 parent f19a1ac commit 5666ffd
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 52 deletions.
111 changes: 62 additions & 49 deletions router/qrouter/proxy_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type RoutingMetadataContext struct {
// SELECT * FROM a join b WHERE a.c1 = <val> and a.c2 = <val>
// can be routed with different rules
rels map[RelationFQN]struct{}
exprs map[RelationFQN]map[string]string
exprs map[RelationFQN]map[string][]string

// cached CTE names
cteNames map[string]struct{}
Expand All @@ -65,7 +65,7 @@ func NewRoutingMetadataContext(params [][]byte, paramsFormatCodes []int16) *Rout
rels: map[RelationFQN]struct{}{},
cteNames: map[string]struct{}{},
tableAliases: map[string]RelationFQN{},
exprs: map[RelationFQN]map[string]string{},
exprs: map[RelationFQN]map[string][]string{},
unparsed_columns: map[string]struct{}{},
params: params,
}
Expand Down Expand Up @@ -97,17 +97,21 @@ func (meta *RoutingMetadataContext) RFQNIsCTE(resolvedRelation RelationFQN) bool
}

// TODO : unit tests
func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN, colname string, expr string) {
func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN, colname string, expr string) error {
if meta.RFQNIsCTE(resolvedRelation) {
// CTE, skip
return
return nil
}
meta.rels[resolvedRelation] = struct{}{}
if _, ok := meta.exprs[resolvedRelation]; !ok {
meta.exprs[resolvedRelation] = map[string]string{}
meta.exprs[resolvedRelation] = map[string][]string{}
}
delete(meta.unparsed_columns, colname)
meta.exprs[resolvedRelation][colname] = expr
if _, ok := meta.exprs[resolvedRelation][colname]; !ok {
meta.exprs[resolvedRelation][colname] = make([]string, 0)
}
meta.exprs[resolvedRelation][colname] = append(meta.exprs[resolvedRelation][colname], expr)
return nil
}

// TODO : unit tests
Expand Down Expand Up @@ -185,11 +189,11 @@ func (qr *ProxyQrouter) DeparseKeyWithRangesInternal(_ context.Context, key stri
return nil, FailedToFindKeyRange
}

func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMetadataContext, resolvedRelation RelationFQN, colname, value string) {
func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMetadataContext, resolvedRelation RelationFQN, colname, value string) error {

/* do not process non-distributed relations or columns not from relation distribution key */
if ds, err := qr.Mgr().GetRelationDistribution(context.TODO(), resolvedRelation.RelationName); err != nil {
return
return nil
} else {
// TODO: optimize
ok := false
Expand All @@ -201,12 +205,12 @@ func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMeta
}
if !ok {
// some junk column
return
return nil
}
}

// will not work not ints
meta.RecordConstExpr(resolvedRelation, colname, value)
return meta.RecordConstExpr(resolvedRelation, colname, value)
}

// TODO : unit tests
Expand All @@ -232,30 +236,27 @@ func (qr *ProxyQrouter) RecordDistributionKeyExprOnRFQN(meta *RoutingMetadataCon
// ??? protoc violation
}

qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(routeParam))
return nil
return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(routeParam))
case *lyx.AExprSConst:
qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(e.Value))
return nil
return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(e.Value))
case *lyx.AExprIConst:
val := fmt.Sprintf("%d", e.Value)
qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(val))
return nil
return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(val))
default:
return ComplexQuery
}
}

func (qr *ProxyQrouter) RecordDistributionKeyColumnValue(meta *RoutingMetadataContext, alias, colname, value string) {
func (qr *ProxyQrouter) RecordDistributionKeyColumnValue(meta *RoutingMetadataContext, alias, colname, value string) error {

resolvedRelation, err := meta.ResolveRelationByAlias(alias)
if err != nil {
// failed to resolve relation, skip column
meta.unparsed_columns[colname] = struct{}{}
return
return nil
}

qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, value)
return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, value)
}

// routeByClause de-parses sharding column-value pair from Where clause of the query
Expand All @@ -281,27 +282,37 @@ func (qr *ProxyQrouter) routeByClause(ctx context.Context, expr lyx.Node, meta *
switch rght := texpr.Right.(type) {
case *lyx.ParamRef:
if rght.Number <= len(meta.params) {
qr.RecordDistributionKeyColumnValue(meta, alias, colname, string(meta.params[rght.Number-1]))
if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, string(meta.params[rght.Number-1])); err != nil {
return err
}
}
// else error out?
case *lyx.AExprSConst:
// TBD: postpone routing from here to root of parsing tree
qr.RecordDistributionKeyColumnValue(meta, alias, colname, rght.Value)
if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, rght.Value); err != nil {
return err
}
case *lyx.AExprIConst:
// TBD: postpone routing from here to root of parsing tree
// maybe expimely inefficient. Will be fixed in SPQR-2.0
qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", rght.Value))
if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", rght.Value)); err != nil {
return err
}
case *lyx.AExprList:
if len(rght.List) != 0 {
expr := rght.List[0]
switch bexpr := expr.(type) {
case *lyx.AExprSConst:
// TBD: postpone routing from here to root of parsing tree
qr.RecordDistributionKeyColumnValue(meta, alias, colname, bexpr.Value)
if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, bexpr.Value); err != nil {
return err
}
case *lyx.AExprIConst:
// TBD: postpone routing from here to root of parsing tree
// maybe expimely inefficient. Will be fixed in SPQR-2.0
qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", bexpr.Value))
if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", bexpr.Value)); err != nil {
return err
}
}
}
case *lyx.FuncApplication:
Expand Down Expand Up @@ -900,7 +911,7 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s

ok := true

var hashedKey []byte
var hashedKeys [][]byte

// TODO: multi-column routing. This works only for one-dim routing
for i := 0; i < len(distrKey); i++ {
Expand All @@ -913,45 +924,47 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s

col := distrKey[i].Column

val, valOk := meta.exprs[rfqn][col]
vals, valOk := meta.exprs[rfqn][col]
if !valOk {
ok = false
break
}

hashedKey, err = hashfunction.ApplyHashFunction([]byte(val), hf)

spqrlog.Zero.Debug().Str("key", meta.exprs[rfqn][col]).Str("hashed key", string(hashedKey)).Msg("applying hash function on key")
hashedKeys = make([][]byte, len(vals))
for i, val := range vals {
hashedKeys[i], err = hashfunction.ApplyHashFunction([]byte(val), hf)
spqrlog.Zero.Debug().Str("key", meta.exprs[rfqn][col][i]).Str("hashed key", string(hashedKeys[i])).Msg("applying hash function on key")

if err != nil {
spqrlog.Zero.Debug().Err(err).Msg("failed to apply hash function")
ok = false
break
if err != nil {
spqrlog.Zero.Debug().Err(err).Msg("failed to apply hash function")
ok = false
break
}
}
}

if !ok {
// skip this relation
continue
}
for _, hashedKey := range hashedKeys {
currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs)
if err != nil {
route_err = err
spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error")
continue
}

currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs)
if err != nil {
route_err = err
spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error")
continue
}

spqrlog.Zero.Debug().
Interface("currroute", currroute).
Str("table", rfqn.RelationName).
Msg("calculated route for table/cols")

route = routingstate.Combine(route, routingstate.ShardMatchState{
Route: currroute,
TargetSessionAttrs: tsa,
})
spqrlog.Zero.Debug().
Interface("currroute", currroute).
Str("table", rfqn.RelationName).
Msg("calculated route for table/cols")

route = routingstate.Combine(route, routingstate.ShardMatchState{
Route: currroute,
TargetSessionAttrs: tsa,
})
}
}
if route == nil && route_err != nil {
return nil, route_err
Expand Down
67 changes: 66 additions & 1 deletion router/qrouter/proxy_routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,61 @@ func TestCTE(t *testing.T) {
TargetSessionAttrs: "any",
},
},
{
query: `
WITH xxxx AS (
SELECT * from t where i = 1
),
zzzz AS (
UPDATE t
SET a = 0
WHERE i = 12
)
SELECT * FROM xxxx;
`,
err: nil,
exp: routingstate.SkipRoutingState{},
},
{
query: `
WITH xxxx AS (
SELECT * from t where i = 1
),
zzzz AS (
UPDATE t
SET a = 0
WHERE i = 2
)
SELECT * FROM xxxx;
`,
err: nil,
exp: routingstate.ShardMatchState{
Route: &routingstate.DataShardRoute{
Shkey: kr.ShardKey{
Name: "sh1",
},
Matchedkr: &kr.KeyRange{
ShardID: "sh1",
ID: "id1",
Distribution: distribution,
LowerBound: []byte("1"),
},
},
TargetSessionAttrs: "any",
},
},
} {
parserRes, err := lyx.Parse(tt.query)

assert.NoError(err, "query %s", tt.query)

tmp, err := pr.Route(context.TODO(), parserRes, session.NewDummyHandler(distribution))

assert.NoError(err, "query %s", tt.query)
if tt.err == nil {
assert.NoError(err, "query %s", tt.query)
} else {
assert.Error(err, "query %s", tt.query)
}

assert.Equal(tt.exp, tmp, tt.query)
}
Expand Down Expand Up @@ -612,6 +659,24 @@ func TestSingleShard(t *testing.T) {
err: nil,
},

{
query: "SELECT * FROM t WHERE i = 12 AND j = 1;",
exp: routingstate.ShardMatchState{
Route: &routingstate.DataShardRoute{
Shkey: kr.ShardKey{
Name: "sh2",
},
Matchedkr: &kr.KeyRange{
ShardID: "sh2",
ID: "id2",
Distribution: distribution,
LowerBound: []byte("11"),
},
},
TargetSessionAttrs: "any",
},
err: nil,
},
{
query: "SELECT * FROM t WHERE i = 12 UNION ALL SELECT * FROM xxmixed WHERE i = 22;",
exp: routingstate.ShardMatchState{
Expand Down
2 changes: 1 addition & 1 deletion test/regress/tests/router/expected/single_shard_joins.out
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ NOTICE: send query to shard(s) : sh2
12 | 13
(2 rows)

SELECT * FROM sshjt1 WHERE i = 12 AND j =1;
SELECT * FROM sshjt1 WHERE i = 12 AND j = 1;
NOTICE: send query to shard(s) : sh2
i | j
---+---
Expand Down
2 changes: 1 addition & 1 deletion test/regress/tests/router/sql/single_shard_joins.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ INSERT INTO sshjt1 (i, j) VALUES(12, 12);
INSERT INTO sshjt1 (i, j) VALUES(12, 13);

SELECT * FROM sshjt1 WHERE i = 12;
SELECT * FROM sshjt1 WHERE i = 12 AND j =1;
SELECT * FROM sshjt1 WHERE i = 12 AND j = 1;

SELECT * FROM sshjt1 a join sshjt1 b WHERE a.i = 12 ON TRUE;
SELECT * FROM sshjt1 a join sshjt1 b ON TRUE WHERE a.i = 12;
Expand Down

0 comments on commit 5666ffd

Please sign in to comment.