diff --git a/router/qrouter/proxy_routing.go b/router/qrouter/proxy_routing.go index b891fc6de..b92baf050 100644 --- a/router/qrouter/proxy_routing.go +++ b/router/qrouter/proxy_routing.go @@ -97,17 +97,21 @@ func (meta *RoutingMetadataContext) RFQNIsCTE(resolvedRelation RelationFQN) bool } // TODO : unit tests -func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN, colname string, expr string) { +func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN, colname string, expr string) error { if meta.RFQNIsCTE(resolvedRelation) { // CTE, skip - return + return nil } meta.rels[resolvedRelation] = struct{}{} if _, ok := meta.exprs[resolvedRelation]; !ok { meta.exprs[resolvedRelation] = map[string]string{} - } // TODO: else branch + } delete(meta.unparsed_columns, colname) + if curExpr, ok := meta.exprs[resolvedRelation][colname]; ok && curExpr != expr { + return spqrerror.Newf(spqrerror.SPQR_COMPLEX_QUERY, "several different values for distribution key") + } meta.exprs[resolvedRelation][colname] = expr + return nil } // TODO : unit tests @@ -185,11 +189,11 @@ func (qr *ProxyQrouter) DeparseKeyWithRangesInternal(_ context.Context, key stri return nil, FailedToFindKeyRange } -func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMetadataContext, resolvedRelation RelationFQN, colname, value string) { +func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMetadataContext, resolvedRelation RelationFQN, colname, value string) error { /* do not process non-distributed relations or columns not from relation distribution key */ if ds, err := qr.Mgr().GetRelationDistribution(context.TODO(), resolvedRelation.RelationName); err != nil { - return + return nil } else { // TODO: optimize ok := false @@ -201,12 +205,12 @@ func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMeta } if !ok { // some junk column - return + return nil } } // will not work not ints - meta.RecordConstExpr(resolvedRelation, colname, value) + return meta.RecordConstExpr(resolvedRelation, colname, value) } // TODO : unit tests @@ -232,30 +236,27 @@ func (qr *ProxyQrouter) RecordDistributionKeyExprOnRFQN(meta *RoutingMetadataCon // ??? protoc violation } - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(routeParam)) - return nil + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(routeParam)) case *lyx.AExprSConst: - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(e.Value)) - return nil + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(e.Value)) case *lyx.AExprIConst: val := fmt.Sprintf("%d", e.Value) - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(val)) - return nil + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(val)) default: return ComplexQuery } } -func (qr *ProxyQrouter) RecordDistributionKeyColumnValue(meta *RoutingMetadataContext, alias, colname, value string) { +func (qr *ProxyQrouter) RecordDistributionKeyColumnValue(meta *RoutingMetadataContext, alias, colname, value string) error { resolvedRelation, err := meta.ResolveRelationByAlias(alias) if err != nil { // failed to resolve relation, skip column meta.unparsed_columns[colname] = struct{}{} - return + return nil } - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, value) + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, value) } // routeByClause de-parses sharding column-value pair from Where clause of the query @@ -281,27 +282,27 @@ func (qr *ProxyQrouter) routeByClause(ctx context.Context, expr lyx.Node, meta * switch rght := texpr.Right.(type) { case *lyx.ParamRef: if rght.Number <= len(meta.params) { - qr.RecordDistributionKeyColumnValue(meta, alias, colname, string(meta.params[rght.Number-1])) + return qr.RecordDistributionKeyColumnValue(meta, alias, colname, string(meta.params[rght.Number-1])) } // else error out? case *lyx.AExprSConst: // TBD: postpone routing from here to root of parsing tree - qr.RecordDistributionKeyColumnValue(meta, alias, colname, rght.Value) + return qr.RecordDistributionKeyColumnValue(meta, alias, colname, rght.Value) case *lyx.AExprIConst: // TBD: postpone routing from here to root of parsing tree // maybe expimely inefficient. Will be fixed in SPQR-2.0 - qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", rght.Value)) + return qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", rght.Value)) case *lyx.AExprList: if len(rght.List) != 0 { expr := rght.List[0] switch bexpr := expr.(type) { case *lyx.AExprSConst: // TBD: postpone routing from here to root of parsing tree - qr.RecordDistributionKeyColumnValue(meta, alias, colname, bexpr.Value) + return qr.RecordDistributionKeyColumnValue(meta, alias, colname, bexpr.Value) case *lyx.AExprIConst: // TBD: postpone routing from here to root of parsing tree // maybe expimely inefficient. Will be fixed in SPQR-2.0 - qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", bexpr.Value)) + return qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", bexpr.Value)) } } case *lyx.FuncApplication: diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index 7c21fd1f3..4fd347844 100644 --- a/router/qrouter/proxy_routing_test.go +++ b/router/qrouter/proxy_routing_test.go @@ -2,6 +2,7 @@ package qrouter_test import ( "context" + "github.com/pg-sharding/spqr/pkg/models/spqrerror" "testing" "github.com/pg-sharding/spqr/pkg/config" @@ -304,8 +305,8 @@ func TestCTE(t *testing.T) { ) SELECT * FROM xxxx; `, - err: nil, - exp: routingstate.MultiMatchState{}, + err: spqrerror.Newf(spqrerror.SPQR_COMPLEX_QUERY, "several different values for distribution key."), + exp: nil, }, } { parserRes, err := lyx.Parse(tt.query) @@ -314,7 +315,11 @@ func TestCTE(t *testing.T) { tmp, err := pr.Route(context.TODO(), parserRes, session.NewDummyHandler(distribution)) - assert.NoError(err, "query %s", tt.query) + if tt.err == nil { + assert.NoError(err, "query %s", tt.query) + } else { + assert.Error(err, "query %s", tt.query) + } assert.Equal(tt.exp, tmp, tt.query) }