Skip to content

Commit

Permalink
Change yacc and protos for multidim key ranges. (#545)
Browse files Browse the repository at this point in the history
* Change yacc and protos for multidim key ranges.
  • Loading branch information
reshke authored Jul 8, 2024
1 parent c0dae6c commit 86af555
Show file tree
Hide file tree
Showing 59 changed files with 2,295 additions and 1,075 deletions.
59 changes: 54 additions & 5 deletions balancer/provider/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package provider

import (
"context"
"encoding/binary"
"fmt"
"sort"
"strconv"
"strings"

"github.com/google/uuid"
Expand All @@ -15,6 +17,7 @@ import (
"github.com/pg-sharding/spqr/pkg/models/tasks"
protos "github.com/pg-sharding/spqr/pkg/protos"
"github.com/pg-sharding/spqr/pkg/spqrlog"
"github.com/pg-sharding/spqr/qdb"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
Expand Down Expand Up @@ -347,10 +350,11 @@ func (b *BalancerImpl) getKRCondition(rel *distributions.DistributedRelation, kR
} else {
hashedCol = entry.Column
}
// TODO: fix multidim case
if nextKR != nil {
buf[i] = fmt.Sprintf("%s >= %s AND %s < %s", hashedCol, string(kRange.LowerBound), hashedCol, string(nextKR.LowerBound))
buf[i] = fmt.Sprintf("%s >= %s AND %s < %s", hashedCol, kRange.SendRaw()[0], hashedCol, nextKR.SendRaw()[0])
} else {
buf[i] = fmt.Sprintf("%s >= %s", hashedCol, string(kRange.LowerBound))
buf[i] = fmt.Sprintf("%s >= %s", hashedCol, kRange.SendRaw()[0])
}
}
return strings.Join(buf, " AND "), nil
Expand Down Expand Up @@ -501,6 +505,7 @@ func (b *BalancerImpl) getTasks(ctx context.Context, shardFrom *ShardMetrics, kr
maxCount = count
}
}

var rel *distributions.DistributedRelation = nil
allRels, err := b.getKRRelations(ctx, b.dsToKeyRanges[ds][krInd])
if err != nil {
Expand All @@ -516,6 +521,16 @@ func (b *BalancerImpl) getTasks(ctx context.Context, shardFrom *ShardMetrics, kr
return nil, fmt.Errorf("relation \"%s\" not found", relName)
}

dsService := protos.NewDistributionServiceClient(b.coordinatorConn)

dsS, err := dsService.GetDistribution(ctx, &protos.GetDistributionRequest{
Id: ds,
})

if err != nil {
return nil, err
}

moveCount := min((keyCount+config.BalancerConfig().KeysPerMove-1)/config.BalancerConfig().KeysPerMove, config.BalancerConfig().MaxMoveCount)

counts := make([]int, moveCount)
Expand Down Expand Up @@ -552,12 +567,38 @@ func (b *BalancerImpl) getTasks(ctx context.Context, shardFrom *ShardMetrics, kr
if err := row.Scan(&idx); err != nil {
return nil, err
}

var bound []byte

switch dsS.Distribution.ColumnTypes[0] {
case qdb.ColumnTypeVarchar:
fallthrough
case qdb.ColumnTypeVarcharDeprecated:
bound = []byte(idx)
case qdb.ColumnTypeVarcharHashed:
fallthrough
case qdb.ColumnTypeInteger:
i, err := strconv.ParseInt(idx, 10, 64)
if err != nil {
return nil, err
}
bound = make([]byte, 8)
binary.PutVarint(bound, i)
case qdb.ColumnTypeUinteger:
i, err := strconv.ParseUint(idx, 10, 64)
if err != nil {
return nil, err
}
bound = make([]byte, 8)
binary.PutUvarint(bound, i)
}

groupTasks[len(groupTasks)-1-i] = &tasks.Task{
ShardFromId: shardFrom.ShardId,
ShardToId: shardToId,
KrIdFrom: krId,
KrIdTo: krIdTo,
Bound: []byte(idx),
Bound: bound,
}
totalCount += count
}
Expand Down Expand Up @@ -667,6 +708,7 @@ func (b *BalancerImpl) executeTasks(ctx context.Context, group *tasks.TaskGroup)

func (b *BalancerImpl) updateKeyRanges(ctx context.Context) error {
keyRangeService := protos.NewKeyRangeServiceClient(b.coordinatorConn)
distrService := protos.NewDistributionServiceClient(b.coordinatorConn)
keyRangesProto, err := keyRangeService.ListAllKeyRanges(ctx, &protos.ListAllKeyRangesRequest{})
if err != nil {
return err
Expand All @@ -676,18 +718,25 @@ func (b *BalancerImpl) updateKeyRanges(ctx context.Context) error {
if _, ok := keyRanges[krProto.DistributionId]; !ok {
keyRanges[krProto.DistributionId] = make([]*kr.KeyRange, 0)
}
keyRanges[krProto.DistributionId] = append(keyRanges[krProto.DistributionId], kr.KeyRangeFromProto(krProto))
ds, err := distrService.GetDistribution(ctx, &protos.GetDistributionRequest{
Id: krProto.DistributionId,
})
if err != nil {
return err
}
keyRanges[krProto.DistributionId] = append(keyRanges[krProto.DistributionId], kr.KeyRangeFromProto(krProto, ds.Distribution.ColumnTypes))
}
for _, krs := range keyRanges {
sort.Slice(krs, func(i, j int) bool {
return kr.CmpRangesLess(krs[i].LowerBound, krs[j].LowerBound)
return kr.CmpRangesLess(krs[i].LowerBound, krs[j].LowerBound, krs[j].ColumnTypes)
})
}

b.dsToKeyRanges = keyRanges
b.dsToKrIdx = make(map[string]map[string]int)
b.shardKr = make(map[string][]string)
b.krToDs = make(map[string]string)

for ds, krs := range b.dsToKeyRanges {
for i, krg := range krs {
b.krToDs[krg.ID] = ds
Expand Down
24 changes: 17 additions & 7 deletions cmd/mover/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"context"
"flag"
"fmt"
"github.com/pg-sharding/spqr/pkg/models/distributions"
"io"
"os"
"strings"

"github.com/pg-sharding/spqr/pkg/models/distributions"

"github.com/jackc/pgx/v5"
_ "github.com/lib/pq"
"github.com/pg-sharding/spqr/pkg/models/kr"
Expand Down Expand Up @@ -103,10 +104,10 @@ FROM information_schema.tables;
// TODO: support multi-column move in SPQR2
if nextKeyRange == nil {
qry = fmt.Sprintf("copy (delete from %s WHERE %s >= %s returning *) to stdout", rel.Name,
rel.DistributionKey[0].Column, keyRange.LowerBound)
rel.DistributionKey[0].Column, keyRange.SendRaw()[0])
} else {
qry = fmt.Sprintf("copy (delete from %s WHERE %s >= %s and %s < %s returning *) to stdout", rel.Name,
rel.DistributionKey[0].Column, keyRange.LowerBound, rel.DistributionKey[0].Column, nextKeyRange.LowerBound)
rel.DistributionKey[0].Column, keyRange.SendRaw()[0], rel.DistributionKey[0].Column, nextKeyRange.SendRaw()[0])
}

spqrlog.Zero.Debug().
Expand All @@ -132,6 +133,7 @@ FROM information_schema.tables;
spqrlog.Zero.Debug().Msg("copy cmd executed")
}

/* TODO: handle errors here */
_ = txTo.Commit(ctx)
_ = txFrom.Commit(ctx)
return nil
Expand Down Expand Up @@ -165,7 +167,14 @@ func main() {
spqrlog.Zero.Error().Err(err).Msg("")
return
}
keyRange := kr.KeyRangeFromDB(qdbKr)

ds, err := db.GetDistribution(ctx, qdbKr.DistributionId)
if err != nil {
spqrlog.Zero.Error().Err(err).Msg("")
return
}

keyRange := kr.KeyRangeFromDB(qdbKr, ds.ColTypes)

krs, err := db.ListKeyRanges(ctx, keyRange.Distribution)
if err != nil {
Expand All @@ -176,9 +185,10 @@ func main() {
var nextKeyRange *kr.KeyRange

for _, currkr := range krs {
if kr.CmpRangesLess(keyRange.LowerBound, currkr.LowerBound) {
if nextKeyRange == nil || kr.CmpRangesLess(currkr.LowerBound, nextKeyRange.LowerBound) {
nextKeyRange = kr.KeyRangeFromDB(currkr)
typedKr := kr.KeyRangeFromDB(currkr, ds.ColTypes)
if kr.CmpRangesLess(keyRange.LowerBound, typedKr.LowerBound, ds.ColTypes) {
if nextKeyRange == nil || kr.CmpRangesLess(typedKr.LowerBound, nextKeyRange.LowerBound, ds.ColTypes) {
nextKeyRange = typedKr
}
}
}
Expand Down
21 changes: 16 additions & 5 deletions cmd/spqrdump/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"

"github.com/pg-sharding/spqr/pkg/conn"
"github.com/pg-sharding/spqr/pkg/models/kr"
"github.com/pg-sharding/spqr/pkg/spqrlog"

"github.com/pg-sharding/spqr/pkg/decode"
Expand Down Expand Up @@ -116,14 +117,16 @@ func getconn() (*pgproto3.Frontend, error) {
// TODO : unit tests
func DumpKeyRangesPsql() error {
return dumpPsql("SHOW key_ranges;", func(v *pgproto3.DataRow) (string, error) {
l := string(v.Values[2])
l := v.Values[2]
id := string(v.Values[0])
shard := string(v.Values[1])

return decode.KeyRange(
&protos.KeyRangeInfo{
KeyRange: &protos.KeyRange{LowerBound: l},
ShardId: shard, Krid: id}), nil
&kr.KeyRange{
LowerBound: []interface{}{l},
ID: id,
ShardID: shard,
}), nil
})
}

Expand Down Expand Up @@ -171,13 +174,21 @@ func DumpKeyRanges() error {
}

rCl := protos.NewKeyRangeServiceClient(cc)
dCl := protos.NewDistributionServiceClient(cc)
if keys, err := rCl.ListAllKeyRanges(context.Background(), &protos.ListAllKeyRangesRequest{}); err != nil {
spqrlog.Zero.Error().
Err(err).
Msg("failed to dump endpoint rules")
} else {
for _, krg := range keys.KeyRangesInfo {
fmt.Println(decode.KeyRange(krg))
ds, err := dCl.GetDistribution(context.Background(), &protos.GetDistributionRequest{
Id: krg.DistributionId,
})
if err != nil {
return err
}
krCurr := kr.KeyRangeFromProto(krg, ds.Distribution.ColumnTypes)
fmt.Println(decode.KeyRange(krCurr))
}
}

Expand Down
Loading

0 comments on commit 86af555

Please sign in to comment.