Skip to content

Commit

Permalink
Testing retry and timeout for signing ops (#366)
Browse files Browse the repository at this point in the history
* Testing retry and timeout for signing ops

* bugfixes

* Adjust use of context for retry/timeout

---------

Co-authored-by: Leland Garofalo <[email protected]>
  • Loading branch information
lgarofalo and Leland Garofalo authored Jul 25, 2023
1 parent 83f280f commit 27c8c29
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 7 deletions.
13 changes: 13 additions & 0 deletions cmd/gokeyless/gokeyless.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ type Config struct {
TracingEnabled bool `yaml:"tracing_enabled" mapstructure:"tracing_enabled"`
TracingAddress string `yaml:"tracing_address" mapstructure:"tracing_address"`
TracingSampleRate float64 `yaml:"tracing_sample_rate" mapstructure:"tracing_sample_rate"` // between 0 and 1

SignTimeout string `yaml:"sign_timeout" mapstructure:"sign_timeout"`
SignRetryCount int `yaml:"sign_retry_count" mapstructure:"sign_retry_count"`
}

// PrivateKeyStoreConfig defines a key store.
Expand Down Expand Up @@ -309,6 +312,16 @@ func runMain() error {
}

cfg := server.DefaultServeConfig()
if config.SignTimeout != "" {
signTimeoutDuration, err := time.ParseDuration(config.SignTimeout)
if err != nil {
log.Fatalf("failed to parse signTimeout: %s", err)
}
cfg = cfg.WithSignTimeout(signTimeoutDuration)
}
if config.SignRetryCount > 0 {
cfg = cfg.WithSignRetryCount(config.SignRetryCount)
}
s, err := server.NewServerFromFile(cfg, config.CertFile, config.KeyFile, config.CACertFile)
if err != nil {
return fmt.Errorf("cannot start server: %w", err)
Expand Down
81 changes: 75 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ type Server struct {
listeners map[net.Listener]map[net.Conn]struct{}
shutdown bool
mtx sync.Mutex

signTimeout time.Duration
signRetryCount int
}

// NewServer prepares a TLS server capable of receiving connections from keyless clients.
Expand All @@ -73,6 +76,8 @@ func NewServer(config *ServeConfig, cert tls.Certificate, keylessCA *x509.CertPo
dispatcher: rpc.NewServer(),
limitedDispatcher: rpc.NewServer(),
listeners: make(map[net.Listener]map[net.Conn]struct{}),
signTimeout: config.signTimeout,
signRetryCount: config.signRetryCount,
}

return s, nil
Expand Down Expand Up @@ -448,19 +453,56 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

signSpan, _ := opentracing.StartSpanFromContext(ctx, "execute.Sign")
signSpan, ctx := opentracing.StartSpanFromContext(ctx, "execute.Sign")
defer signSpan.Finish()
var sig []byte
sig, err = key.Sign(rand.Reader, pkt.Operation.Payload, opts)
if err != nil {
tracing.LogError(span, err)
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
// By default, we only try the request once, unless retry count is configured
for attempts := 1 + s.signRetryCount; attempts > 0; attempts-- {
var err error
// If signTimeout is not set, the value will be zero
if s.signTimeout == 0 {
sig, err = key.Sign(rand.Reader, pkt.Operation.Payload, opts)
} else {
ch := make(chan signWithTimeoutWrapper, 1)
ctxTimeout, cancel := context.WithTimeout(ctx, s.signTimeout)
defer cancel()

go signWithTimeout(ctxTimeout, ch, key, rand.Reader, pkt.Operation.Payload, opts)
select {
case <-ctxTimeout.Done():
sig = nil
err = ctxTimeout.Err()
case result := <-ch:
sig = result.sig
err = result.error
}
}
if err != nil {
if attempts > 1 {
log.Debugf("Connection %v: failed sign attempt: %s, %d attempt(s) left", connName, err, attempts-1)
continue
} else {
tracing.LogError(span, err)
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
break
}

return makeRespondResponse(pkt, sig)
}

// signWithTimeoutWrapper bundles the two return values of a crypto.Signer Sign
// call so they can be delivered together over a single channel by
// signWithTimeout.
type signWithTimeoutWrapper struct {
// sig is the signature returned by Sign.
sig []byte
// error is the error returned by Sign, or nil if signing succeeded.
error error
}

// signWithTimeout performs one key.Sign call and publishes its outcome
// (signature plus error) on ch as a single signWithTimeoutWrapper value.
//
// It is meant to be started in its own goroutine so that the caller can race
// the result against a deadline. ctx is accepted but not consulted here: the
// Sign call itself is not interrupted, so it always runs to completion. The
// send on ch blocks unless the channel has room for the value or a receiver is
// waiting.
func signWithTimeout(ctx context.Context, ch chan signWithTimeoutWrapper, key crypto.Signer, rand io.Reader, digest []byte, opts crypto.SignerOpts) {
	var outcome signWithTimeoutWrapper
	outcome.sig, outcome.error = key.Sign(rand, digest, opts)
	ch <- outcome
}

func (s *Server) limitedDo(pkt *protocol.Packet, connName string) response {
spanCtx, err := tracing.SpanContextFromBinary(pkt.Operation.JaegerSpan)
if err != nil {
Expand Down Expand Up @@ -697,6 +739,8 @@ type ServeConfig struct {
tcpTimeout, unixTimeout time.Duration
isLimited func(state tls.ConnectionState) (bool, error)
customOpFunc CustomOpFunction
signTimeout time.Duration
signRetryCount int
}

const (
Expand All @@ -718,6 +762,8 @@ func DefaultServeConfig() *ServeConfig {
unixTimeout: defaultUnixTimeout,
maxConnPendingRequests: 1024,
isLimited: func(state tls.ConnectionState) (bool, error) { return false, nil },
signTimeout: 0,
signRetryCount: 0,
}
}

Expand Down Expand Up @@ -757,6 +803,29 @@ func (s *ServeConfig) WithIsLimited(f func(state tls.ConnectionState) (bool, err
return s
}

// WithSignTimeout specifies the sign operation timeout. This timeout is used to
// enforce a max execution time for a single sign operation (each retry gets its
// own timeout).
//
// A zero timeout disables the limit, which is the default. Negative timeouts
// are treated as zero: passed through unchanged they would produce an
// already-expired deadline and make every sign operation fail immediately.
//
// It returns the receiver so calls can be chained.
func (s *ServeConfig) WithSignTimeout(timeout time.Duration) *ServeConfig {
	if timeout < 0 {
		timeout = 0
	}
	s.signTimeout = timeout
	return s
}

// SignTimeout returns the sign operation timeout currently configured on s.
// A zero value means no timeout is applied to sign operations.
func (s *ServeConfig) SignTimeout() time.Duration {
return s.signTimeout
}

// WithSignRetryCount specifies the number of retries to allow for failed sign
// operations, in addition to the initial attempt. A count of zero (the default)
// means each sign operation is tried exactly once.
//
// Negative counts are treated as zero: the server derives its attempt budget
// as 1 + retry count, so a negative value would otherwise leave no attempts at
// all and yield a response without a signature.
//
// It returns the receiver so calls can be chained.
func (s *ServeConfig) WithSignRetryCount(signRetryCount int) *ServeConfig {
	if signRetryCount < 0 {
		signRetryCount = 0
	}
	s.signRetryCount = signRetryCount
	return s
}

// SignRetryCount returns the number of retries allowed for a failed sign
// operation, not counting the initial attempt.
func (s *ServeConfig) SignRetryCount() int {
return s.signRetryCount
}

// CustomOpFunction is the signature for custom opcode functions.
//
// If it returns a non-nil error which implements protocol.Error, the server
Expand Down
11 changes: 10 additions & 1 deletion tests/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ type IntegrationTestSuite struct {
ecdsaKey *client.PrivateKey
ed25519Key *client.PrivateKey
remote client.Remote

retryCount int
timeout time.Duration
}

func fixedCurrentTime() time.Time {
Expand Down Expand Up @@ -148,6 +151,11 @@ func (s *IntegrationTestSuite) NewRemoteSignerByPubKeyFile(filepath string) (cry
// TestSuite runs the integration suite twice, in order: first with the
// zero-value settings (no sign timeout, no retries), then with a one-second
// sign timeout and three retries.
func TestSuite(t *testing.T) {
	variants := []*IntegrationTestSuite{
		{},
		{timeout: time.Second, retryCount: 3},
	}
	for _, variant := range variants {
		suite.Run(t, variant)
	}
}

// SetupTest sets up a compatible server and client for use by tests.
Expand All @@ -160,7 +168,8 @@ func (s *IntegrationTestSuite) SetupTest() {
atomic.StoreUint32(&client.TestDisableConnectionPool, 1)

var err error
s.server, err = server.NewServerFromFile(nil, serverCert, serverKey, keylessCA)
config := server.DefaultServeConfig().WithSignTimeout(s.timeout).WithSignRetryCount(s.retryCount)
s.server, err = server.NewServerFromFile(config, serverCert, serverKey, keylessCA)
require.NoError(err)
s.server.TLSConfig().Time = fixedCurrentTime

Expand Down

0 comments on commit 27c8c29

Please sign in to comment.