Skip to content

Commit

Permalink
vault: add retries to read-only calls
Browse files Browse the repository at this point in the history
  • Loading branch information
gmichelo committed Nov 20, 2024
1 parent 86d9089 commit c7956b7
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 106 deletions.
51 changes: 2 additions & 49 deletions internal/fnapi/fnapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,19 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"syscall"
"time"

"github.com/cenkalti/backoff"
"github.com/spf13/pflag"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"namespacelabs.dev/foundation/internal/console"
"namespacelabs.dev/foundation/internal/fnerrors"
"namespacelabs.dev/foundation/internal/versions"
"namespacelabs.dev/foundation/std/tryhard"
"namespacelabs.dev/go-ids"
"namespacelabs.dev/integrations/nsc/apienv"
)
Expand Down Expand Up @@ -185,7 +181,7 @@ func (c Call[RequestT]) Do(ctx context.Context, request RequestT, resolveEndpoin

fmt.Fprintf(console.Debug(ctx), "[%s] Body: %s\n", tid, reqDebugBytes)

return callSideEffectFree(ctx, c.Retryable, func(ctx context.Context) error {
return tryhard.CallSideEffectFree0(ctx, c.Retryable, func(ctx context.Context) error {
t := time.Now()
url := endpoint + "/" + c.Method
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBytes))
Expand Down Expand Up @@ -305,49 +301,6 @@ func (c Call[RequestT]) Do(ctx context.Context, request RequestT, resolveEndpoin
})
}

func callSideEffectFree(ctx context.Context, retryable bool, method func(context.Context) error) error {
if !retryable {
return method(ctx)
}

b := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 5 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Clock: backoff.SystemClock,
}

b.Reset()

span := trace.SpanFromContext(ctx)

return backoff.Retry(func() error {
if methodErr := method(ctx); methodErr != nil {
// grpc's ConnectionError have a Temporary() signature. If we, for example, write to
// a channel and that channel is gone, then grpc observes a ECONNRESET. And propagates
// it as a temporary error. It doesn't know though whether it's safe to retry, so it
// doesn't.
if temp, ok := methodErr.(interface{ Temporary() bool }); ok && temp.Temporary() {
span.RecordError(methodErr, trace.WithAttributes(attribute.Bool("grpc.temporary_error", true)))
return methodErr
}

var netErr *net.OpError
if errors.As(methodErr, &netErr) {
if errno, ok := netErr.Err.(syscall.Errno); ok && errno == syscall.ECONNRESET {
return methodErr // Retry
}
}

return backoff.Permanent(methodErr)
}

return nil
}, backoff.WithContext(b, ctx))
}

func handleGrpcStatus(url string, st *spb.Status) error {
switch st.Code {
case int32(codes.Unauthenticated):
Expand Down
13 changes: 8 additions & 5 deletions internal/providers/vault/approle.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"namespacelabs.dev/foundation/internal/fnerrors"
"namespacelabs.dev/foundation/std/cfg"
"namespacelabs.dev/foundation/std/tasks"
"namespacelabs.dev/foundation/std/tryhard"
"namespacelabs.dev/foundation/universe/vault"
)

Expand Down Expand Up @@ -43,11 +44,13 @@ func createSecretId(ctx context.Context, vaultClient *vaultclient.Client, vaultC
var err error
creds.RoleId, err = tasks.Return(ctx, tasks.Action("vault.read-role-id").Arg("name", cfg.GetName()),
func(ctx context.Context) (string, error) {
res, err := vaultClient.Auth.AppRoleReadRoleId(ctx, cfg.GetName(), wmp)
if err != nil {
return "", fnerrors.InvocationError("vault", "failed to read role id: %w", err)
}
return res.Data.RoleId, nil
return tryhard.CallSideEffectFree1(ctx, true, func(ctx context.Context) (string, error) {
res, err := vaultClient.Auth.AppRoleReadRoleId(ctx, cfg.GetName(), wmp)
if err != nil {
return "", fnerrors.InvocationError("vault", "failed to read role id: %w", err)
}
return res.Data.RoleId, nil
})
})
return err
})
Expand Down
49 changes: 26 additions & 23 deletions internal/providers/vault/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"namespacelabs.dev/foundation/internal/fnerrors"
"namespacelabs.dev/foundation/std/cfg"
"namespacelabs.dev/foundation/std/tasks"
"namespacelabs.dev/foundation/std/tryhard"
"namespacelabs.dev/foundation/universe/vault"
)

Expand Down Expand Up @@ -74,29 +75,31 @@ func certificateProvider(ctx context.Context, conf cfg.Configuration, secretId s
func issueCertificate(ctx context.Context, vaultClient *vaultclient.Client, pkiMount, pkiRole string, req certificateRequest) ([]byte, error) {
return tasks.Return(ctx, tasks.Action("vault.issue-certificate").Arg("pki-mount", pkiMount).Arg("pki-role", pkiRole).Arg("common-name", req.commonName),
func(ctx context.Context) ([]byte, error) {
issueResp, err := vaultClient.Secrets.PkiIssueWithRole(ctx, pkiRole,
schema.PkiIssueWithRoleRequest{
CommonName: req.commonName,
AltNames: strings.Join(req.sans, ","),
ExcludeCnFromSans: req.excludeCnFromSans,
IpSans: req.ipSans,
},
vaultclient.WithMountPath(pkiMount),
)
if err != nil {
return nil, fnerrors.InvocationError("vault", "failed to issue a certificate: %w", err)
}

data, err := vault.TlsBundle{
PrivateKeyPem: issueResp.Data.PrivateKey,
CertificatePem: issueResp.Data.Certificate,
CaChainPem: issueResp.Data.CaChain,
}.Encode()
if err != nil {
return nil, fnerrors.BadDataError("failed to serialize certificate data: %w", err)
}

return data, nil
return tryhard.CallSideEffectFree1(ctx, true, func(ctx context.Context) ([]byte, error) {
issueResp, err := vaultClient.Secrets.PkiIssueWithRole(ctx, pkiRole,
schema.PkiIssueWithRoleRequest{
CommonName: req.commonName,
AltNames: strings.Join(req.sans, ","),
ExcludeCnFromSans: req.excludeCnFromSans,
IpSans: req.ipSans,
},
vaultclient.WithMountPath(pkiMount),
)
if err != nil {
return nil, fnerrors.InvocationError("vault", "failed to issue a certificate: %w", err)
}

data, err := vault.TlsBundle{
PrivateKeyPem: issueResp.Data.PrivateKey,
CertificatePem: issueResp.Data.Certificate,
CaChainPem: issueResp.Data.CaChain,
}.Encode()
if err != nil {
return nil, fnerrors.BadDataError("failed to serialize certificate data: %w", err)
}

return data, nil
})
},
)
}
61 changes: 32 additions & 29 deletions internal/providers/vault/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"namespacelabs.dev/foundation/internal/fnerrors"
"namespacelabs.dev/foundation/std/cfg"
"namespacelabs.dev/foundation/std/tasks"
"namespacelabs.dev/foundation/std/tryhard"
"namespacelabs.dev/foundation/universe/vault"
)

Expand All @@ -29,34 +30,36 @@ func secretProvider(ctx context.Context, conf cfg.Configuration, secretId secret

return tasks.Return(ctx, tasks.Action("vault.read-secret").Arg("ref", secretRef),
func(ctx context.Context) ([]byte, error) {
secretPkg, secretKey, found := strings.Cut(secretRef, ":")
if !found {
return nil, fnerrors.BadInputError("invalid vault secret reference: expects secret refernece in format '<mount>/<path>:<key>'")
}

secretMount, secretPath, found := strings.Cut(secretPkg, "/")
if !found {
return nil, fnerrors.BadInputError("invalid vault secret package: expects secret package in format '<mount>/<path>'")
}
vaultClient, err := login(ctx, vaultConfig)
if err != nil {
return nil, err
}

secretResp, err := vaultClient.Secrets.KvV2Read(ctx, secretPath, vaultclient.WithMountPath(secretMount))
if err != nil {
return nil, fnerrors.InvocationError("vault", "failed to read a secret: %w", err)
}

if secretResp.Data.Data == nil {
return nil, fnerrors.InvocationError("vault", "secret response contained no data")
}

secret, ok := secretResp.Data.Data[secretKey].(string)
if !ok {
return nil, fnerrors.InvocationError("vault", "response data contained no expected secret %q", secretKey)
}

return []byte(secret), nil
return tryhard.CallSideEffectFree1(ctx, true, func(ctx context.Context) ([]byte, error) {
secretPkg, secretKey, found := strings.Cut(secretRef, ":")
if !found {
return nil, fnerrors.BadInputError("invalid vault secret reference: expects secret refernece in format '<mount>/<path>:<key>'")
}

secretMount, secretPath, found := strings.Cut(secretPkg, "/")
if !found {
return nil, fnerrors.BadInputError("invalid vault secret package: expects secret package in format '<mount>/<path>'")
}
vaultClient, err := login(ctx, vaultConfig)
if err != nil {
return nil, err
}

secretResp, err := vaultClient.Secrets.KvV2Read(ctx, secretPath, vaultclient.WithMountPath(secretMount))
if err != nil {
return nil, fnerrors.InvocationError("vault", "failed to read a secret: %w", err)
}

if secretResp.Data.Data == nil {
return nil, fnerrors.InvocationError("vault", "secret response contained no data")
}

secret, ok := secretResp.Data.Data[secretKey].(string)
if !ok {
return nil, fnerrors.InvocationError("vault", "response data contained no expected secret %q", secretKey)
}

return []byte(secret), nil
})
})
}
69 changes: 69 additions & 0 deletions std/tryhard/tryhard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package tryhard

import (
"context"
"errors"
"net"
"syscall"
"time"

"github.com/cenkalti/backoff"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

func CallSideEffectFree0(ctx context.Context, retryable bool, method func(context.Context) error) error {
_, err := CallSideEffectFree1(ctx, retryable, func(ctx context.Context) (any, error) {
return nil, method(ctx)
})

return err
}

func CallSideEffectFree1[T any](ctx context.Context, retryable bool, method func(context.Context) (T, error)) (T, error) {
if !retryable {
return method(ctx)
}

b := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 5 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Clock: backoff.SystemClock,
}

b.Reset()

span := trace.SpanFromContext(ctx)

var finalRet T
err := backoff.Retry(func() error {
ret, methodErr := method(ctx)
if methodErr != nil {
// grpc's ConnectionError have a Temporary() signature. If we, for example, write to
// a channel and that channel is gone, then grpc observes a ECONNRESET. And propagates
// it as a temporary error. It doesn't know though whether it's safe to retry, so it
// doesn't.
if temp, ok := methodErr.(interface{ Temporary() bool }); ok && temp.Temporary() {
span.RecordError(methodErr, trace.WithAttributes(attribute.Bool("grpc.temporary_error", true)))
return methodErr
}

var netErr *net.OpError
if errors.As(methodErr, &netErr) {
if errno, ok := netErr.Err.(syscall.Errno); ok && errno == syscall.ECONNRESET {
return methodErr // Retry
}
}

return backoff.Permanent(methodErr)
}

finalRet = ret
return nil
}, backoff.WithContext(b, ctx))

return finalRet, err
}

0 comments on commit c7956b7

Please sign in to comment.