Skip to content

Commit

Permalink
Refactor protospecs (#790)
Browse files Browse the repository at this point in the history
* Refactor protospecs

* Fix feature tests
  • Loading branch information
EinKrebs authored Oct 31, 2024
1 parent b612e96 commit 026d31c
Show file tree
Hide file tree
Showing 44 changed files with 1,198 additions and 2,665 deletions.
6 changes: 3 additions & 3 deletions balancer/provider/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ func (b *BalancerImpl) getTasks(ctx context.Context, shardFrom *ShardMetrics, kr

func (b *BalancerImpl) getCurrentTaskGroupFromQDB(ctx context.Context) (group *tasks.TaskGroup, err error) {
tasksService := protos.NewTasksServiceClient(b.coordinatorConn)
resp, err := tasksService.GetTaskGroup(ctx, &protos.GetTaskGroupRequest{})
resp, err := tasksService.GetTaskGroup(ctx, nil)
if err != nil {
return nil, err
}
Expand All @@ -623,7 +623,7 @@ func (b *BalancerImpl) syncTaskGroupWithQDB(ctx context.Context, group *tasks.Ta

func (b *BalancerImpl) removeTaskGroupFromQDB(ctx context.Context) error {
tasksService := protos.NewTasksServiceClient(b.coordinatorConn)
_, err := tasksService.RemoveTaskGroup(ctx, &protos.RemoveTaskGroupRequest{})
_, err := tasksService.RemoveTaskGroup(ctx, nil)
return err
}

Expand Down Expand Up @@ -709,7 +709,7 @@ func (b *BalancerImpl) executeTasks(ctx context.Context, group *tasks.TaskGroup)
func (b *BalancerImpl) updateKeyRanges(ctx context.Context) error {
keyRangeService := protos.NewKeyRangeServiceClient(b.coordinatorConn)
distrService := protos.NewDistributionServiceClient(b.coordinatorConn)
keyRangesProto, err := keyRangeService.ListAllKeyRanges(ctx, &protos.ListAllKeyRangesRequest{})
keyRangesProto, err := keyRangeService.ListAllKeyRanges(ctx, nil)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/coordctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ var listRouterCmd = &cobra.Command{
}

rCl := protos.NewRouterServiceClient(cc)
if resp, err := rCl.ListRouters(context.Background(), &protos.ListRoutersRequest{}); err == nil {
if resp, err := rCl.ListRouters(context.Background(), nil); err == nil {
fmt.Printf("-------------------------------------\n")
fmt.Printf("%d routers found\n", len(resp.Routers))

Expand Down Expand Up @@ -186,7 +186,7 @@ var listShardCmd = &cobra.Command{
}

rCl := protos.NewShardServiceClient(cc)
if resp, err := rCl.ListShards(context.Background(), &protos.ListShardsRequest{}); err == nil {
if resp, err := rCl.ListShards(context.Background(), nil); err == nil {
fmt.Printf("-------------------------------------\n")
fmt.Printf("%d shards found\n", len(resp.Shards))

Expand Down
4 changes: 2 additions & 2 deletions cmd/spqrdump/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func DumpKeyRanges() error {

rCl := protos.NewKeyRangeServiceClient(cc)
dCl := protos.NewDistributionServiceClient(cc)
if keys, err := rCl.ListAllKeyRanges(context.Background(), &protos.ListAllKeyRangesRequest{}); err != nil {
if keys, err := rCl.ListAllKeyRanges(context.Background(), nil); err != nil {
spqrlog.Zero.Error().
Err(err).
Msg("failed to dump endpoint rules")
Expand Down Expand Up @@ -204,7 +204,7 @@ func DumpDistributions() error {
}

rCl := protos.NewDistributionServiceClient(cc)
if dss, err := rCl.ListDistributions(context.Background(), &protos.ListDistributionsRequest{}); err != nil {
if dss, err := rCl.ListDistributions(context.Background(), nil); err != nil {
spqrlog.Zero.Error().
Err(err).
Msg("failed to dump endpoint distributions")
Expand Down
20 changes: 10 additions & 10 deletions coordinator/provider/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (ci grpcConnectionIterator) ClientPoolForeach(cb func(client client.ClientI
rrClient := routerproto.NewClientInfoServiceClient(cc)

spqrlog.Zero.Debug().Msg("fetch clients with grpc")
resp, err := rrClient.ListClients(ctx, &routerproto.ListClientsRequest{})
resp, err := rrClient.ListClients(ctx, nil)
if err != nil {
spqrlog.Zero.Error().Msg("error fetching clients with grpc")
return err
Expand Down Expand Up @@ -128,7 +128,7 @@ func (ci grpcConnectionIterator) ForEach(cb func(sh shard.Shardinfo) error) erro
rrBackConn := routerproto.NewBackendConnectionsServiceClient(cc)

spqrlog.Zero.Debug().Msg("fetch clients with grpc")
resp, err := rrBackConn.ListBackendConnections(ctx, &routerproto.ListBackendConnectionsRequest{})
resp, err := rrBackConn.ListBackendConnections(ctx, nil)
if err != nil {
spqrlog.Zero.Error().Msg("error fetching clients with grpc")
return err
Expand All @@ -151,7 +151,7 @@ func (ci grpcConnectionIterator) ForEachPool(cb func(p pool.Pool) error) error {
rrBackConn := routerproto.NewPoolServiceClient(cc)

spqrlog.Zero.Debug().Msg("fetch pools with grpc")
resp, err := rrBackConn.ListPools(ctx, &routerproto.ListPoolsRequest{})
resp, err := rrBackConn.ListPools(ctx, nil)
if err != nil {
spqrlog.Zero.Error().Msg("error fetching pools with grpc")
return err
Expand Down Expand Up @@ -230,7 +230,7 @@ func (qc *qdbCoordinator) watchRouters(ctx context.Context) {

rrClient := routerproto.NewTopologyServiceClient(cc)

resp, err := rrClient.GetRouterStatus(ctx, &routerproto.GetRouterStatusRequest{})
resp, err := rrClient.GetRouterStatus(ctx, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -694,7 +694,7 @@ func (qc *qdbCoordinator) DropKeyRangeAll(ctx context.Context) error {

if err := qc.traverseRouters(ctx, func(cc *grpc.ClientConn) error {
cl := routerproto.NewKeyRangeServiceClient(cc)
resp, err := cl.DropAllKeyRanges(ctx, &routerproto.DropAllKeyRangesRequest{})
resp, err := cl.DropAllKeyRanges(ctx, nil)
spqrlog.Zero.Debug().Err(err).
Interface("response", resp).
Msg("drop key range response")
Expand Down Expand Up @@ -987,7 +987,7 @@ func (qc *qdbCoordinator) SyncRouterMetadata(ctx context.Context, qRouter *topol
if err != nil {
return err
}
resp, err := dsCl.ListDistributions(ctx, &routerproto.ListDistributionsRequest{})
resp, err := dsCl.ListDistributions(ctx, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -1021,7 +1021,7 @@ func (qc *qdbCoordinator) SyncRouterMetadata(ctx context.Context, qRouter *topol
if err != nil {
return err
}
if _, err = krClient.DropAllKeyRanges(ctx, &routerproto.DropAllKeyRangesRequest{}); err != nil {
if _, err = krClient.DropAllKeyRanges(ctx, nil); err != nil {
return err
}

Expand Down Expand Up @@ -1051,7 +1051,7 @@ func (qc *qdbCoordinator) SyncRouterMetadata(ctx context.Context, qRouter *topol
return err
}

if resp, err := rCl.OpenRouter(ctx, &routerproto.OpenRouterRequest{}); err != nil {
if resp, err := rCl.OpenRouter(ctx, nil); err != nil {
return err
} else {
spqrlog.Zero.Debug().
Expand Down Expand Up @@ -1084,7 +1084,7 @@ func (qc *qdbCoordinator) SyncRouterCoordinatorAddress(ctx context.Context, qRou
return err
}

if resp, err := rCl.OpenRouter(ctx, &routerproto.OpenRouterRequest{}); err != nil {
if resp, err := rCl.OpenRouter(ctx, nil); err != nil {
return err
} else {
spqrlog.Zero.Debug().
Expand All @@ -1110,7 +1110,7 @@ func (qc *qdbCoordinator) RegisterRouter(ctx context.Context, r *topology.Router
}
defer conn.Close()
cl := routerproto.NewTopologyServiceClient(conn)
_, err = cl.GetRouterStatus(ctx, &routerproto.GetRouterStatusRequest{})
_, err = cl.GetRouterStatus(ctx, nil)
if err != nil {
return spqrerror.Newf(spqrerror.SPQR_CONNECTION_ERROR, "failed to ping router: %s", err)
}
Expand Down
20 changes: 11 additions & 9 deletions coordinator/provider/distributions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package provider

import (
"context"

"github.com/pg-sharding/spqr/coordinator"
"github.com/pg-sharding/spqr/pkg/models/distributions"
protos "github.com/pg-sharding/spqr/pkg/protos"
"google.golang.org/protobuf/types/known/emptypb"
)

type DistributionsServer struct {
Expand All @@ -21,25 +23,25 @@ func NewDistributionServer(impl coordinator.Coordinator) *DistributionsServer {

var _ protos.DistributionServiceServer = &DistributionsServer{}

func (d *DistributionsServer) CreateDistribution(ctx context.Context, req *protos.CreateDistributionRequest) (*protos.CreateDistributionReply, error) {
func (d *DistributionsServer) CreateDistribution(ctx context.Context, req *protos.CreateDistributionRequest) (*emptypb.Empty, error) {
for _, ds := range req.Distributions {
if err := d.impl.CreateDistribution(ctx, distributions.DistributionFromProto(ds)); err != nil {
return nil, err
}
}
return &protos.CreateDistributionReply{}, nil
return nil, nil
}

func (d *DistributionsServer) DropDistribution(ctx context.Context, req *protos.DropDistributionRequest) (*protos.DropDistributionReply, error) {
func (d *DistributionsServer) DropDistribution(ctx context.Context, req *protos.DropDistributionRequest) (*emptypb.Empty, error) {
for _, id := range req.GetIds() {
if err := d.impl.DropDistribution(ctx, id); err != nil {
return nil, err
}
}
return &protos.DropDistributionReply{}, nil
return nil, nil
}

func (d *DistributionsServer) ListDistributions(ctx context.Context, req *protos.ListDistributionsRequest) (*protos.ListDistributionsReply, error) {
func (d *DistributionsServer) ListDistributions(ctx context.Context, _ *emptypb.Empty) (*protos.ListDistributionsReply, error) {
dss, err := d.impl.ListDistributions(ctx)
if err != nil {
return nil, err
Expand All @@ -55,8 +57,8 @@ func (d *DistributionsServer) ListDistributions(ctx context.Context, req *protos
}, nil
}

func (d *DistributionsServer) AlterDistributionAttach(ctx context.Context, req *protos.AlterDistributionAttachRequest) (*protos.AlterDistributionAttachReply, error) {
return &protos.AlterDistributionAttachReply{}, d.impl.AlterDistributionAttach(ctx, req.GetId(), func() []*distributions.DistributedRelation {
func (d *DistributionsServer) AlterDistributionAttach(ctx context.Context, req *protos.AlterDistributionAttachRequest) (*emptypb.Empty, error) {
return nil, d.impl.AlterDistributionAttach(ctx, req.GetId(), func() []*distributions.DistributedRelation {
res := make([]*distributions.DistributedRelation, len(req.GetRelations()))
for i, rel := range req.GetRelations() {
res[i] = distributions.DistributedRelationFromProto(rel)
Expand All @@ -65,13 +67,13 @@ func (d *DistributionsServer) AlterDistributionAttach(ctx context.Context, req *
}())
}

func (d *DistributionsServer) AlterDistributionDetach(ctx context.Context, req *protos.AlterDistributionDetachRequest) (*protos.AlterDistributionDetachReply, error) {
func (d *DistributionsServer) AlterDistributionDetach(ctx context.Context, req *protos.AlterDistributionDetachRequest) (*emptypb.Empty, error) {
for _, rel := range req.GetRelNames() {
if err := d.impl.AlterDistributionDetach(ctx, req.GetId(), rel); err != nil {
return nil, err
}
}
return &protos.AlterDistributionDetachReply{}, nil
return nil, nil
}

func (d *DistributionsServer) GetDistribution(ctx context.Context, req *protos.GetDistributionRequest) (*protos.GetDistributionReply, error) {
Expand Down
5 changes: 3 additions & 2 deletions coordinator/provider/keyranges.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/pg-sharding/spqr/pkg/models/spqrerror"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/pg-sharding/spqr/coordinator"
"github.com/pg-sharding/spqr/pkg/models/kr"
Expand All @@ -17,7 +18,7 @@ type CoordinatorService struct {
}

// DropAllKeyRanges implements proto.KeyRangeServiceServer.
func (c *CoordinatorService) DropAllKeyRanges(ctx context.Context, request *protos.DropAllKeyRangesRequest) (*protos.DropAllKeyRangesResponse, error) {
func (c *CoordinatorService) DropAllKeyRanges(ctx context.Context, request *emptypb.Empty) (*protos.DropAllKeyRangesResponse, error) {
err := c.impl.DropKeyRangeAll(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -125,7 +126,7 @@ func (c *CoordinatorService) ListKeyRange(ctx context.Context, request *protos.L
}, nil
}

func (c *CoordinatorService) ListAllKeyRanges(ctx context.Context, _ *protos.ListAllKeyRangesRequest) (*protos.KeyRangeReply, error) {
func (c *CoordinatorService) ListAllKeyRanges(ctx context.Context, _ *emptypb.Empty) (*protos.KeyRangeReply, error) {
krsDb, err := c.impl.ListAllKeyRanges(ctx)
if err != nil {
return nil, err
Expand Down
11 changes: 6 additions & 5 deletions coordinator/provider/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/pg-sharding/spqr/pkg/models/topology"
protos "github.com/pg-sharding/spqr/pkg/protos"
"github.com/pg-sharding/spqr/pkg/spqrlog"
"google.golang.org/protobuf/types/known/emptypb"
)

type RouterService struct {
Expand All @@ -16,7 +17,7 @@ type RouterService struct {
}

// TODO : unit tests
func (r RouterService) ListRouters(ctx context.Context, request *protos.ListRoutersRequest) (*protos.ListRoutersReply, error) {
func (r RouterService) ListRouters(ctx context.Context, _ *emptypb.Empty) (*protos.ListRoutersReply, error) {
routers, err := r.impl.ListRouters(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -48,27 +49,27 @@ func (r RouterService) AddRouter(ctx context.Context, request *protos.AddRouterR
}

// TODO : unit tests
func (r RouterService) RemoveRouter(ctx context.Context, request *protos.RemoveRouterRequest) (*protos.RemoveRouterReply, error) {
func (r RouterService) RemoveRouter(ctx context.Context, request *protos.RemoveRouterRequest) (*emptypb.Empty, error) {
spqrlog.Zero.Debug().
Str("router-id", request.Id).
Msg("unregister router in coordinator")
err := r.impl.UnregisterRouter(ctx, request.Id)
if err != nil {
return nil, err
}
return &protos.RemoveRouterReply{}, nil
return nil, nil
}

// TODO : unit tests
func (r RouterService) SyncMetadata(ctx context.Context, request *protos.SyncMetadataRequest) (*protos.SyncMetadataReply, error) {
func (r RouterService) SyncMetadata(ctx context.Context, request *protos.SyncMetadataRequest) (*emptypb.Empty, error) {
spqrlog.Zero.Debug().
Str("router-id", request.Router.Id).
Msg("sync router metadata in coordinator")
err := r.impl.SyncRouterMetadata(ctx, topology.RouterFromProto(request.Router))
if err != nil {
return nil, err
}
return &protos.SyncMetadataReply{}, nil
return nil, nil
}

var _ protos.RouterServiceServer = &RouterService{}
Expand Down
9 changes: 5 additions & 4 deletions coordinator/provider/shards.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
routerproto "github.com/pg-sharding/spqr/pkg/protos"
"github.com/pg-sharding/spqr/pkg/shard"
"github.com/pg-sharding/spqr/pkg/txstatus"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/pg-sharding/spqr/coordinator"
"github.com/pg-sharding/spqr/pkg/models/datashards"
Expand All @@ -27,23 +28,23 @@ func NewShardServer(impl coordinator.Coordinator) *ShardServer {
var _ protos.ShardServiceServer = &ShardServer{}

// TODO : unit tests
func (s *ShardServer) AddDataShard(ctx context.Context, request *protos.AddShardRequest) (*protos.AddShardReply, error) {
func (s *ShardServer) AddDataShard(ctx context.Context, request *protos.AddShardRequest) (*emptypb.Empty, error) {
newShard := request.GetShard()

if err := s.impl.AddDataShard(ctx, datashards.DataShardFromProto(newShard)); err != nil {
return nil, err
}

return &protos.AddShardReply{}, nil
return nil, nil
}

func (s *ShardServer) AddWorldShard(ctx context.Context, request *protos.AddWorldShardRequest) (*protos.AddShardReply, error) {
func (s *ShardServer) AddWorldShard(ctx context.Context, request *protos.AddWorldShardRequest) (*emptypb.Empty, error) {
panic("implement me")
}

// TODO : unit tests
// TODO: remove ShardRequest.
func (s *ShardServer) ListShards(ctx context.Context, _ *protos.ListShardsRequest) (*protos.ListShardsReply, error) {
func (s *ShardServer) ListShards(ctx context.Context, _ *emptypb.Empty) (*protos.ListShardsReply, error) {
shardList, err := s.impl.ListShards(ctx)
if err != nil {
return nil, err
Expand Down
12 changes: 7 additions & 5 deletions coordinator/provider/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package provider

import (
"context"

"github.com/pg-sharding/spqr/coordinator"
"github.com/pg-sharding/spqr/pkg/models/tasks"
protos "github.com/pg-sharding/spqr/pkg/protos"
"google.golang.org/protobuf/types/known/emptypb"
)

type TasksServer struct {
Expand All @@ -21,19 +23,19 @@ func NewTasksServer(impl coordinator.Coordinator) *TasksServer {

var _ protos.TasksServiceServer = &TasksServer{}

func (t TasksServer) GetTaskGroup(ctx context.Context, _ *protos.GetTaskGroupRequest) (*protos.GetTaskGroupReply, error) {
func (t TasksServer) GetTaskGroup(ctx context.Context, _ *emptypb.Empty) (*protos.GetTaskGroupReply, error) {
group, err := t.impl.GetTaskGroup(ctx)
if err != nil {
return nil, err
}
return &protos.GetTaskGroupReply{TaskGroup: tasks.TaskGroupToProto(group)}, nil
}

func (t TasksServer) WriteTaskGroup(ctx context.Context, request *protos.WriteTaskGroupRequest) (*protos.WriteTaskGroupReply, error) {
func (t TasksServer) WriteTaskGroup(ctx context.Context, request *protos.WriteTaskGroupRequest) (*emptypb.Empty, error) {
err := t.impl.WriteTaskGroup(ctx, tasks.TaskGroupFromProto(request.TaskGroup))
return &protos.WriteTaskGroupReply{}, err
return nil, err
}

func (t TasksServer) RemoveTaskGroup(ctx context.Context, _ *protos.RemoveTaskGroupRequest) (*protos.RemoveTaskGroupReply, error) {
return &protos.RemoveTaskGroupReply{}, t.impl.RemoveTaskGroup(ctx)
func (t TasksServer) RemoveTaskGroup(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
return nil, t.impl.RemoveTaskGroup(ctx)
}
Loading

0 comments on commit 026d31c

Please sign in to comment.