Skip to content

Commit

Permalink
feat(server): create query for list and lookup worker to use
Browse files Browse the repository at this point in the history
  • Loading branch information
irenarindos committed Nov 8, 2024
1 parent 12d8898 commit 1de4022
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ func (s Service) toProto(ctx context.Context, in *server.Worker, opt ...handlers
out.LastStatusTime = in.GetLastStatusTime().GetTimestamp()
}
if outputFields.Has(globals.ActiveConnectionCountField) {
out.ActiveConnectionCount = &wrapperspb.UInt32Value{Value: in.ActiveConnectionCount()}
out.ActiveConnectionCount = &wrapperspb.UInt32Value{Value: in.GetActiveConnectionCount()}
}
if outputFields.Has(globals.ControllerGeneratedActivationToken) && in.ControllerGeneratedActivationToken != "" {
out.ControllerGeneratedActivationToken = &wrapperspb.StringValue{Value: in.ControllerGeneratedActivationToken}
Expand Down
33 changes: 33 additions & 0 deletions internal/server/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,39 @@
package server

const (
listWorkersQuery = `
with connection_count (worker_id, count) as (
select worker_id,
count(1) as count
from session_connection
where closed_reason is null
group by worker_id
)
select w.public_id,
w.scope_id,
w.description,
w.name,
w.address,
w.create_time,
w.update_time,
w.version,
w.last_status_time,
w.type,
w.release_version,
w.operational_state,
w.local_storage_state,
cc.count as active_connection_count,
wt.tags as api_tags,
ct.tags as config_tags
from server_worker w
left join (select worker_id, json_agg(json_build_object('key', key, 'value', value)) as tags from server_worker_api_tag group by worker_id) wt
on w.public_id = wt.worker_id
left join (select worker_id, json_agg(json_build_object('key', key, 'value', value)) as tags from server_worker_config_tag group by worker_id) ct
on w.public_id = ct.worker_id
left join connection_count as cc
on w.public_id = cc.worker_id
`

getStorageBucketCredentialStatesByWorkerId = `
select spsb.public_id as storage_bucket_id,
wsbcs.permission_type, wsbcs.state,
Expand Down
91 changes: 44 additions & 47 deletions internal/server/repository_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package server

import (
"context"
"database/sql"
stderrors "errors"
"fmt"
"strings"
Expand Down Expand Up @@ -117,20 +118,20 @@ func lookupWorker(ctx context.Context, reader db.Reader, id string) (*Worker, er
case id == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "id is empty")
}
wAgg := &workerAggregate{}
wAgg.PublicId = id
err := reader.LookupById(ctx, wAgg)

lookupQuery := fmt.Sprintf("%s where w.public_id = @worker_id", listWorkersQuery)

rows, err := reader.Query(ctx, lookupQuery, []any{sql.Named("worker_id", id)})
if err != nil {
if errors.IsNotFoundError(err) {
return nil, nil
}
return nil, errors.Wrap(ctx, err, op)
}
w, err := wAgg.toWorker(ctx)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
defer rows.Close()
for rows.Next() {
var worker Worker
reader.ScanRows(context.Background(), rows, &worker)
return &worker, nil
}
return w, nil
return nil, nil
}

// ListWorkers will return a listing of Workers and honor the WithLimit option.
Expand All @@ -155,6 +156,8 @@ func (r *Repository) ListWorkers(ctx context.Context, scopeIds []string, opt ...
default:
newOpts = append(newOpts, WithLimit(r.defaultLimit))
}

// Note: these options below will be removed in a future PR for this llb
// handle the WithLiveness default
switch {
case opts.withLiveness != 0:
Expand Down Expand Up @@ -187,6 +190,7 @@ func ListWorkers(ctx context.Context, reader db.Reader, scopeIds []string, opt .
}

opts := GetOpts(opt...)
// Note: this option will be removed in a future PR for this llb
liveness := opts.withLiveness
if liveness == 0 {
liveness = DefaultLiveness
Expand All @@ -198,27 +202,30 @@ func ListWorkers(ctx context.Context, reader db.Reader, scopeIds []string, opt .
where = append(where, fmt.Sprintf("last_status_time > now() - interval '%d seconds'", uint32(liveness.Seconds())))
}
if len(scopeIds) > 0 {
where = append(where, "scope_id in (?)")
whereArgs = append(whereArgs, scopeIds)
where = append(where, "w.scope_id in (@scope_id)")
whereArgs = append(whereArgs, sql.Named("scope_id", strings.Join(scopeIds, ",")))
}

// Note: this option will be removed in a future PR for this llb
switch opts.withWorkerType {
case "":
case KmsWorkerType, PkiWorkerType:
where = append(where, "type = ?")
whereArgs = append(whereArgs, opts.withWorkerType.String())
where = append(where, "w.type = @type")
whereArgs = append(whereArgs, sql.Named("type", opts.withWorkerType.String()))
default:
return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unknown worker type %v", opts.withWorkerType))
}

// Note: this option will be removed in a future PR for this llb
if opts.withActiveWorkers {
where = append(where, "operational_state = ?")
whereArgs = append(whereArgs, ActiveOperationalState.String())
where = append(where, "w.operational_state = @operational_state")
whereArgs = append(whereArgs, sql.Named("operational_state", ActiveOperationalState))
}

// Note: this option will be removed in a future PR for this llb
if len(opts.withWorkerPool) > 0 {
where = append(where, "public_id in (?)")
whereArgs = append(whereArgs, opts.withWorkerPool)
whereString := fmt.Sprintf("w.public_id in ('%s')", strings.Join(opts.withWorkerPool, "','"))
where = append(where, whereString)
}

limit := db.DefaultLimit
Expand All @@ -227,25 +234,23 @@ func ListWorkers(ctx context.Context, reader db.Reader, scopeIds []string, opt .
limit = opts.withLimit
}

var wAggs []*workerAggregate
if err := reader.SearchWhere(
ctx,
&wAggs,
strings.Join(where, " and "),
whereArgs,
db.WithLimit(limit),
); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error searching for workers"))
query := fmt.Sprintf("%s where %s", listWorkersQuery, strings.Join(where, " and "))
if limit > 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
}

workers := make([]*Worker, 0, len(wAggs))
for _, a := range wAggs {
w, err := a.toWorker(ctx)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error converting workerAggregate to Worker"))
}
workers = append(workers, w)
var workers []*Worker
rows, err := reader.Query(ctx, query, whereArgs)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error searching for workers"))
}
defer rows.Close()
for rows.Next() {
var worker Worker
reader.ScanRows(context.Background(), rows, &worker)
workers = append(workers, &worker)
}

return workers, nil
}

Expand Down Expand Up @@ -366,15 +371,10 @@ func (r *Repository) UpsertWorkerStatus(ctx context.Context, worker *Worker, opt
}
}

wAgg := &workerAggregate{PublicId: workerClone.GetPublicId()}
if err := reader.LookupById(ctx, wAgg); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("error looking up worker aggregate"))
}
ret, err = wAgg.toWorker(ctx)
ret, err = lookupWorker(ctx, reader, workerClone.GetPublicId())
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("error converting worker aggregate to worker"))
return errors.Wrap(ctx, err, op, errors.WithMsg("error looking up worker"))
}

return nil
},
)
Expand Down Expand Up @@ -554,13 +554,10 @@ func (r *Repository) UpdateWorker(ctx context.Context, worker *Worker, version u
return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated")
}

wAgg = &workerAggregate{PublicId: worker.GetPublicId()}
if err := reader.LookupById(ctx, wAgg); err != nil {
ret, err = lookupWorker(ctx, reader, worker.GetPublicId())
if err != nil {
return errors.Wrap(ctx, err, op)
}
if ret, err = wAgg.toWorker(ctx); err != nil {
return err
}
ret.RemoteStorageStates, err = r.ListWorkerStorageBucketCredentialState(ctx, ret.GetPublicId())
if err != nil {
return err
Expand Down Expand Up @@ -726,7 +723,7 @@ func (r *Repository) AddWorkerTags(ctx context.Context, workerId string, workerV
return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("no worker found with public id %s", workerId))
}

newTags := append(worker.apiTags, tags...)
newTags := append(worker.ApiTags, tags...)
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
worker := worker.clone()
worker.PublicId = workerId
Expand Down
25 changes: 12 additions & 13 deletions internal/server/repository_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestLookupWorker(t *testing.T) {
got, err := repo.LookupWorker(ctx, w.GetPublicId())
require.NoError(t, err)
assert.Empty(t, cmp.Diff(w, got, protocmp.Transform()))
assert.Equal(t, uint32(3), got.ActiveConnectionCount())
assert.Equal(t, uint32(3), got.GetActiveConnectionCount())
assert.Equal(t, map[string][]string{
"key": {"val"},
}, got.CanonicalTags())
Expand All @@ -207,17 +207,6 @@ func TestLookupWorker(t *testing.T) {
require.NoError(t, err)
assert.Nil(t, got)
})
t.Run("db error", func(t *testing.T) {
conn, mock := db.TestSetupWithMock(t)
rw := db.New(conn)
mock.ExpectQuery(`SELECT`).WillReturnError(errors.New(context.Background(), errors.Internal, "test", "lookup-error"))
r, err := server.NewRepository(ctx, rw, rw, kms)
require.NoError(t, err)
got, err := r.LookupWorker(ctx, w.GetPublicId())
assert.NoError(t, mock.ExpectationsWereMet())
assert.Truef(t, errors.Match(errors.T(errors.Op("server.(Repository).LookupWorker")), err), "got error %v", err)
assert.Nil(t, got)
})
}

func TestUpsertWorkerStatus(t *testing.T) {
Expand Down Expand Up @@ -749,7 +738,17 @@ func TestListWorkers_WithWorkerPool(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
got, err := serversRepo.ListWorkers(ctx, []string{scope.Global.String()}, server.WithLiveness(-1), server.WithWorkerPool(tt.workerPool))
require.NoError(err)
assert.ElementsMatch(t, tt.want, got)
assert.Equal(t, len(tt.want), len(got))
found := 0
for _, w := range tt.want {
for _, g := range got {
if w.PublicId == g.PublicId {
assert.Equal(t, w.Worker, g.Worker)
found++
}
}
}
assert.Equal(t, len(tt.want), found)
})
}
}
Expand Down
43 changes: 21 additions & 22 deletions internal/server/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,9 @@ func AttachWorkerIdToState(ctx context.Context, workerId string) (*structpb.Stru
// authorizing and establishing a session. It is owned by a scope.
type Worker struct {
*store.Worker

activeConnectionCount uint32 `gorm:"-"`
apiTags Tags
configTags Tags
ApiTags Tags `json:"api_tags" gorm:"->"`
ConfigTags Tags `json:"config_tags" gorm:"->"`
ActiveConnectionCount uint32 `gorm:"->"`

// inputTags is not specified to be api or config tags and is not intended
// to be read by clients. Since config tags and api tags are applied in
Expand Down Expand Up @@ -149,7 +148,7 @@ func NewWorker(scopeId string, opt ...Option) *Worker {
inputTags: opts.withWorkerTags,
}
if opts.withTestUseInputTagsAsApiTags {
worker.apiTags = worker.inputTags
worker.ApiTags = worker.inputTags
}
return worker
}
Expand All @@ -167,16 +166,16 @@ func (w *Worker) clone() *Worker {
cWorker := &Worker{
Worker: cw.(*store.Worker),
}
if w.apiTags != nil {
cWorker.apiTags = make([]*Tag, 0, len(w.apiTags))
for _, t := range w.apiTags {
cWorker.apiTags = append(cWorker.apiTags, &Tag{Key: t.Key, Value: t.Value})
if w.ApiTags != nil {
cWorker.ApiTags = make([]*Tag, 0, len(w.ApiTags))
for _, t := range w.ApiTags {
cWorker.ApiTags = append(cWorker.ApiTags, &Tag{Key: t.Key, Value: t.Value})
}
}
if w.configTags != nil {
cWorker.configTags = make([]*Tag, 0, len(w.configTags))
for _, t := range w.configTags {
cWorker.configTags = append(cWorker.configTags, &Tag{Key: t.Key, Value: t.Value})
if w.ConfigTags != nil {
cWorker.ConfigTags = make([]*Tag, 0, len(w.ConfigTags))
for _, t := range w.ConfigTags {
cWorker.ConfigTags = append(cWorker.ConfigTags, &Tag{Key: t.Key, Value: t.Value})
}
}
if w.inputTags != nil {
Expand All @@ -190,19 +189,19 @@ func (w *Worker) clone() *Worker {

// ActiveConnectionCount is the current number of sessions this worker is handling
// according to the controllers.
func (w *Worker) ActiveConnectionCount() uint32 {
return w.activeConnectionCount
func (w *Worker) GetActiveConnectionCount() uint32 {
return w.ActiveConnectionCount
}

// CanonicalTags is the deduplicated set of tags contained on both the resource
// set over the API as well as the tags reported by the worker itself. This
// function is guaranteed to return a non-nil map.
func (w *Worker) CanonicalTags(opt ...Option) map[string][]string {
dedupedTags := make(map[Tag]struct{})
for _, t := range w.apiTags {
for _, t := range w.ApiTags {
dedupedTags[*t] = struct{}{}
}
for _, t := range w.configTags {
for _, t := range w.ConfigTags {
dedupedTags[*t] = struct{}{}
}
tags := make(map[string][]string)
Expand All @@ -216,7 +215,7 @@ func (w *Worker) CanonicalTags(opt ...Option) map[string][]string {
// the worker daemon's configuration file.
func (w *Worker) GetConfigTags() map[string][]string {
tags := make(map[string][]string)
for _, t := range w.configTags {
for _, t := range w.ConfigTags {
tags[t.Key] = append(tags[t.Key], t.Value)
}
return tags
Expand All @@ -225,7 +224,7 @@ func (w *Worker) GetConfigTags() map[string][]string {
// GetApiTags returns the api tags which have been set for this worker.
func (w *Worker) GetApiTags() map[string][]string {
tags := make(map[string][]string)
for _, t := range w.apiTags {
for _, t := range w.ApiTags {
tags[t.Key] = append(tags[t.Key], t.Value)
}
return tags
Expand Down Expand Up @@ -286,10 +285,10 @@ func (a *workerAggregate) toWorker(ctx context.Context) (*Worker, error) {
OperationalState: a.OperationalState,
LocalStorageState: a.LocalStorageState,
},
activeConnectionCount: a.ActiveConnectionCount,
ActiveConnectionCount: a.ActiveConnectionCount,
RemoteStorageStates: map[string]*plugin.StorageBucketCredentialState{},
apiTags: a.ApiTags,
configTags: a.WorkerConfigTags,
ApiTags: a.ApiTags,
ConfigTags: a.WorkerConfigTags,
}

return worker, nil
Expand Down
Loading

0 comments on commit 1de4022

Please sign in to comment.