From 052a08ab6da9e37e5addcd52eea212f934e5dedc Mon Sep 17 00:00:00 2001
From: Hayden Blauzvern
Date: Tue, 14 Jan 2025 20:04:41 -0600
Subject: [PATCH] Support per-shard signing keys

This change enables key rotation with a per-shard signing key
configuration. The LogRanges structure now holds both active and
inactive shards, with the LogRange structure containing a signer, a
PEM-encoded public key, and a log ID derived from the public key.

This change is backwards compatible. If no signing configuration is
specified, the active shard's signing configuration is used for all
shards.

Minor change: Standardized log ID vs tree ID, where the former is the
pubkey hash and the latter is the ID for the Trillian tree.

Signed-off-by: Hayden Blauzvern
---
 pkg/api/api.go                 |  55 ++----
 pkg/api/entries.go             |  35 ++--
 pkg/api/public_key.go          |   2 +-
 pkg/api/tlog.go                |  22 ++-
 pkg/sharding/log_index.go      |   4 +-
 pkg/sharding/log_index_test.go |   8 +-
 pkg/sharding/ranges.go         | 160 +++++++++++------
 pkg/sharding/ranges_test.go    | 306 ++++++++++++++++++++++++++++-----
 pkg/signer/signer.go           |  13 ++
 pkg/signer/signer_test.go      |  24 +++
 pkg/signer/tink.go             |   4 +
 tests/sharding-e2e-test.sh     |  17 +-
 12 files changed, 468 insertions(+), 182 deletions(-)
 create mode 100644 pkg/signer/signer_test.go

diff --git a/pkg/api/api.go b/pkg/api/api.go
index ab85437ae..5b267e3c6 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -17,10 +17,8 @@ package api
 import (
 "context"
- "crypto/sha256"
 "crypto/tls"
 "crypto/x509"
- "encoding/hex"
 "fmt"
 "os"
 "path/filepath"
@@ -42,9 +40,6 @@ import (
 "github.com/sigstore/rekor/pkg/storage"
 "github.com/sigstore/rekor/pkg/trillianclient"
 "github.com/sigstore/rekor/pkg/witness"
- "github.com/sigstore/sigstore/pkg/cryptoutils"
- "github.com/sigstore/sigstore/pkg/signature"
- "github.com/sigstore/sigstore/pkg/signature/options"
 _ "github.com/sigstore/rekor/pkg/pubsub/gcp" // Load GCP pubsub implementation
 )
@@ -92,12 +87,9 @@ func dial(rpcServer string) (*grpc.ClientConn, error) {
 }
 type API struct {
- logClient trillian.TrillianLogClient
- logID int64
- logRanges sharding.LogRanges
- pubkey string // PEM encoded public key
- pubkeyHash string // SHA256 hash of DER-encoded public key
- signer signature.Signer
+ logClient trillian.TrillianLogClient
+ treeID int64
+ logRanges sharding.LogRanges
 // stops checkpoint publishing
 checkpointPublishCancel context.CancelFunc
 // Publishes notifications when new entries are added to the log. May be
@@ -117,12 +109,6 @@ func NewAPI(treeID uint) (*API, error) {
 logAdminClient := trillian.NewTrillianAdminClient(tConn)
 logClient := trillian.NewTrillianLogClient(tConn)
- shardingConfig := viper.GetString("trillian_log_server.sharding_config")
- ranges, err := sharding.NewLogRanges(ctx, logClient, shardingConfig, treeID)
- if err != nil {
- return nil, fmt.Errorf("unable get sharding details from sharding config: %w", err)
- }
-
 tid := int64(treeID)
 if tid == 0 {
 log.Logger.Info("No tree ID specified, attempting to create a new tree")
@@ -133,27 +119,18 @@ func NewAPI(treeID uint) (*API, error) {
 tid = t.TreeId
 }
 log.Logger.Infof("Starting Rekor server with active tree %v", tid)
- ranges.SetActive(tid)
- rekorSigner, err := signer.New(ctx, viper.GetString("rekor_server.signer"),
- viper.GetString("rekor_server.signer-passwd"),
- viper.GetString("rekor_server.tink_kek_uri"),
- viper.GetString("rekor_server.tink_keyset_path"),
- )
- if err != nil {
- return nil, fmt.Errorf("getting new signer: %w", err)
- }
- pk, err := rekorSigner.PublicKey(options.WithContext(ctx))
- if err != nil {
- return nil, fmt.Errorf("getting public key: %w", err)
+ shardingConfig := viper.GetString("trillian_log_server.sharding_config")
+ signingConfig := signer.SigningConfig{
+ SigningSchemeOrKeyPath: viper.GetString("rekor_server.signer"),
+ FileSignerPassword: viper.GetString("rekor_server.signer-passwd"),
+ TinkKEKURI: viper.GetString("rekor_server.tink_kek_uri"),
+ TinkKeysetPath: viper.GetString("rekor_server.tink_keyset_path"),
 }
- b, err := x509.MarshalPKIXPublicKey(pk)
+ ranges, err := sharding.NewLogRanges(ctx, logClient, shardingConfig, tid, signingConfig)
 if err != nil {
- return nil, fmt.Errorf("marshalling public key: %w", err)
+ return nil, fmt.Errorf("unable to get sharding details from sharding config: %w", err)
 }
- pubkeyHashBytes := sha256.Sum256(b)
-
- pubkey := cryptoutils.PEMEncode(cryptoutils.PublicKeyPEMType, b)
 var newEntryPublisher pubsub.Publisher
 if p := viper.GetString("rekor_server.new_entry_publisher"); p != "" {
@@ -170,12 +147,8 @@ func NewAPI(treeID uint) (*API, error) {
 return &API{
 // Transparency Log Stuff
 logClient: logClient,
- logID: tid,
+ treeID: tid,
 logRanges: ranges,
- // Signing/verifying fields
- pubkey: string(pubkey),
- pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]),
- signer: rekorSigner,
 // Utility functionality not required for operation of the core service
 newEntryPublisher: newEntryPublisher,
 }, nil
@@ -212,8 +185,8 @@ func ConfigureAPI(treeID uint) {
 if viper.GetBool("enable_stable_checkpoint") {
 redisClient = NewRedisClient()
- checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(),
- viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount)
+ checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.GetActive().TreeID,
+ viper.GetString("rekor_server.hostname"), api.logRanges.GetActive().Signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount)
 // create context to cancel goroutine on server shutdown
 ctx, cancel := context.WithCancel(context.Background())
diff --git a/pkg/api/entries.go b/pkg/api/entries.go
index e54cf0bad..515f835db 100644
--- a/pkg/api/entries.go
+++ b/pkg/api/entries.go
@@ -74,7 +74,7 @@ func signEntry(ctx context.Context, signer signature.Signer, entry models.LogEnt
 }
 // logEntryFromLeaf creates a signed LogEntry struct from trillian structs
-func logEntryFromLeaf(ctx context.Context, signer signature.Signer, _ trillianclient.TrillianClient, leaf *trillian.LogLeaf,
+func logEntryFromLeaf(ctx context.Context, _ trillianclient.TrillianClient, leaf *trillian.LogLeaf,
 signedLogRoot *trillian.SignedLogRoot, proof *trillian.Proof, tid int64, ranges sharding.LogRanges) (models.LogEntry, error) {
 log.ContextLogger(ctx).Debugf("log entry from leaf %d", leaf.GetLeafIndex())
@@ -88,19 +88,24 @@
 }
 virtualIndex := sharding.VirtualLogIndex(leaf.GetLeafIndex(), tid, ranges)
+ logRange, err := ranges.GetLogRangeByTreeID(tid)
+ if err != nil {
+ return nil, err
+ }
+
 logEntryAnon := models.LogEntryAnon{
- LogID: swag.String(api.pubkeyHash),
+ LogID: swag.String(logRange.LogID),
 LogIndex: &virtualIndex,
 Body: leaf.LeafValue,
 IntegratedTime: swag.Int64(leaf.IntegrateTimestamp.AsTime().Unix()),
 }
- signature, err := signEntry(ctx, signer, logEntryAnon)
+ signature, err := signEntry(ctx, logRange.Signer, logEntryAnon)
 if err != nil {
 return nil, fmt.Errorf("signing entry error: %w", err)
 }
- scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, api.signer)
+ scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, logRange.Signer)
 if err != nil {
 return nil, err
 }
@@ -194,7 +199,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
 return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
 }
- tc := trillianclient.NewTrillianClient(ctx, api.logClient, api.logID)
+ tc := trillianclient.NewTrillianClient(ctx, api.logClient, api.treeID)
 resp := tc.AddLeaf(leaf)
 // this represents overall GRPC response state (not the results of insertion into the log)
@@ -209,7 +214,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
 case int32(code.Code_OK):
 case int32(code.Code_ALREADY_EXISTS), int32(code.Code_FAILED_PRECONDITION):
 existingUUID := hex.EncodeToString(rfc6962.DefaultHasher.HashLeaf(leaf))
- activeTree := fmt.Sprintf("%x", api.logID)
+ activeTree := fmt.Sprintf("%x", api.treeID)
 entryIDstruct, err := sharding.CreateEntryIDFromParts(activeTree, existingUUID)
 if err != nil {
 err := fmt.Errorf("error creating EntryID from active treeID %v and uuid %v: %w", activeTree, existingUUID, err)
@@ -230,7 +235,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
 queuedLeaf := resp.GetAddResult.QueuedLeaf.Leaf
 uuid := hex.EncodeToString(queuedLeaf.GetMerkleLeafHash())
- activeTree := fmt.Sprintf("%x", api.logID)
+ activeTree := fmt.Sprintf("%x", api.treeID)
 entryIDstruct, err := sharding.CreateEntryIDFromParts(activeTree, uuid)
 if err != nil {
 err := fmt.Errorf("error creating EntryID from active treeID %v and uuid %v: %w", activeTree, uuid, err)
@@ -239,9 +244,9 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
 entryID := entryIDstruct.ReturnEntryIDString()
 // The log index should be the virtual log index across all shards
- virtualIndex := sharding.VirtualLogIndex(queuedLeaf.LeafIndex, api.logRanges.ActiveTreeID(), api.logRanges)
+ virtualIndex := sharding.VirtualLogIndex(queuedLeaf.LeafIndex, api.logRanges.GetActive().TreeID, api.logRanges)
 logEntryAnon := models.LogEntryAnon{
- LogID: swag.String(api.pubkeyHash),
+ LogID:
swag.String(api.logRanges.GetActive().LogID), LogIndex: swag.Int64(virtualIndex), Body: queuedLeaf.GetLeafValue(), IntegratedTime: swag.Int64(queuedLeaf.IntegrateTimestamp.AsTime().Unix()), @@ -286,7 +291,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl } } - signature, err := signEntry(ctx, api.signer, logEntryAnon) + signature, err := signEntry(ctx, api.logRanges.GetActive().Signer, logEntryAnon) if err != nil { return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("signing entry error: %w", err), signingError) } @@ -300,7 +305,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl hashes = append(hashes, hex.EncodeToString(hash)) } - scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), api.logID, root.TreeSize, root.RootHash, api.signer) + scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), api.treeID, root.TreeSize, root.RootHash, api.logRanges.GetActive().Signer) if err != nil { return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError) } @@ -511,7 +516,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo continue } tcs := trillianclient.NewTrillianClient(httpReqCtx, api.logClient, shard) - logEntry, err := logEntryFromLeaf(httpReqCtx, api.signer, tcs, leafResp.Leaf, leafResp.SignedLogRoot, leafResp.Proof, shard, api.logRanges) + logEntry, err := logEntryFromLeaf(httpReqCtx, tcs, leafResp.Leaf, leafResp.SignedLogRoot, leafResp.Proof, shard, api.logRanges) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error()) } @@ -558,7 +563,7 @@ func retrieveLogEntryByIndex(ctx context.Context, logIndex int) (models.LogEntry return models.LogEntry{}, ErrNotFound } - return logEntryFromLeaf(ctx, api.signer, tc, leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges) + return logEntryFromLeaf(ctx, tc, leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges) } // Retrieve a Log Entry @@ -580,7 +585,7 @@ func retrieveLogEntry(ctx context.Context, entryUUID string) (models.LogEntry, e // If we got a UUID instead of an EntryID, search all shards if errors.Is(err, sharding.ErrPlainUUID) { - trees := []sharding.LogRange{{TreeID: api.logRanges.ActiveTreeID()}} + trees := []sharding.LogRange{api.logRanges.GetActive()} trees = append(trees, api.logRanges.GetInactive()...) 
for _, t := range trees { @@ -623,7 +628,7 @@ func retrieveUUIDFromTree(ctx context.Context, uuid string, tid int64) (models.L return models.LogEntry{}, err } - logEntry, err := logEntryFromLeaf(ctx, api.signer, tc, result.Leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges) + logEntry, err := logEntryFromLeaf(ctx, tc, result.Leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges) if err != nil { return models.LogEntry{}, fmt.Errorf("could not create log entry from leaf: %w", err) } diff --git a/pkg/api/public_key.go b/pkg/api/public_key.go index b4ff91625..819fba9f9 100644 --- a/pkg/api/public_key.go +++ b/pkg/api/public_key.go @@ -27,7 +27,7 @@ import ( func GetPublicKeyHandler(params pubkey.GetPublicKeyParams) middleware.Responder { treeID := swag.StringValue(params.TreeID) - pk, err := api.logRanges.PublicKey(api.pubkey, treeID) + pk, err := api.logRanges.PublicKey(treeID) if err != nil { return handleRekorAPIError(params, http.StatusBadRequest, err, "") } diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go index 96e3be2bf..83bb4d434 100644 --- a/pkg/api/tlog.go +++ b/pkg/api/tlog.go @@ -33,20 +33,18 @@ import ( "github.com/sigstore/rekor/pkg/log" "github.com/sigstore/rekor/pkg/trillianclient" "github.com/sigstore/rekor/pkg/util" + "github.com/sigstore/sigstore/pkg/signature" ) // GetLogInfoHandler returns the current size of the tree and the STH func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { - tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.logID) + tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID) // for each inactive shard, get the loginfo var inactiveShards []*models.InactiveShardLogInfo for _, shard := range api.logRanges.GetInactive() { - if shard.TreeID == api.logRanges.ActiveTreeID() { - break - } // Get details for this inactive shard - is, err := inactiveShardLogInfo(params.HTTPRequest.Context(), shard.TreeID) + is, err := inactiveShardLogInfo(params.HTTPRequest.Context(), shard.TreeID, shard.Signer) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("inactive shard error: %w", err), unexpectedInactiveShardError) } @@ -55,7 +53,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { if swag.BoolValue(params.Stable) && redisClient != nil { // key is treeID/latest - key := fmt.Sprintf("%d/latest", api.logRanges.ActiveTreeID()) + key := fmt.Sprintf("%d/latest", api.logRanges.GetActive().TreeID) redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result() if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, @@ -79,7 +77,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { RootHash: stringPointer(hex.EncodeToString(checkpoint.Hash)), TreeSize: swag.Int64(int64(checkpoint.Size)), SignedTreeHead: stringPointer(string(decoded)), - TreeID: stringPointer(fmt.Sprintf("%d", api.logID)), + TreeID: stringPointer(fmt.Sprintf("%d", api.treeID)), InactiveShards: inactiveShards, } return tlog.NewGetLogInfoOK().WithPayload(&logInfo) @@ -100,7 +98,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { treeSize := int64(root.TreeSize) scBytes, err := util.CreateAndSignCheckpoint(params.HTTPRequest.Context(), - viper.GetString("rekor_server.hostname"), api.logRanges.ActiveTreeID(), root.TreeSize, root.RootHash, api.signer) + viper.GetString("rekor_server.hostname"), api.logRanges.GetActive().TreeID, 
root.TreeSize, root.RootHash, api.logRanges.GetActive().Signer) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError) } @@ -109,7 +107,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { RootHash: &hashString, TreeSize: &treeSize, SignedTreeHead: stringPointer(string(scBytes)), - TreeID: stringPointer(fmt.Sprintf("%d", api.logID)), + TreeID: stringPointer(fmt.Sprintf("%d", api.treeID)), InactiveShards: inactiveShards, } @@ -126,7 +124,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder { errMsg := fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize) return handleRekorAPIError(params, http.StatusBadRequest, fmt.Errorf("consistency proof: %s", errMsg), errMsg) } - tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.logID) + tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID) if treeID := swag.StringValue(params.TreeID); treeID != "" { id, err := strconv.Atoi(treeID) if err != nil { @@ -170,7 +168,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder { return tlog.NewGetLogProofOK().WithPayload(&consistencyProof) } -func inactiveShardLogInfo(ctx context.Context, tid int64) (*models.InactiveShardLogInfo, error) { +func inactiveShardLogInfo(ctx context.Context, tid int64, signer signature.Signer) (*models.InactiveShardLogInfo, error) { tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid) resp := tc.GetLatest(0) if resp.Status != codes.OK { @@ -186,7 +184,7 @@ func inactiveShardLogInfo(ctx context.Context, tid int64) (*models.InactiveShard hashString := hex.EncodeToString(root.RootHash) treeSize := int64(root.TreeSize) - scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, api.signer) + scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, signer) if err != nil { return nil, err } diff --git a/pkg/sharding/log_index.go b/pkg/sharding/log_index.go index dcdfc1085..e07d6033a 100644 --- a/pkg/sharding/log_index.go +++ b/pkg/sharding/log_index.go @@ -19,7 +19,7 @@ func VirtualLogIndex(leafIndex int64, tid int64, ranges LogRanges) int64 { // if we have no inactive ranges, we have just one log! 
return the leafIndex as is
 // as long as it matches the active tree ID
 if ranges.NoInactive() {
- if ranges.GetActive() == tid {
+ if ranges.GetActive().TreeID == tid {
 return leafIndex
 }
 return -1
@@ -34,7 +34,7 @@ func VirtualLogIndex(leafIndex int64, tid int64, ranges LogRanges) int64 {
 }
 // If no TreeID in Inactive matches the tid, the virtual index should be the active tree
- if ranges.GetActive() == tid {
+ if ranges.GetActive().TreeID == tid {
 return virtualIndex + leafIndex
 }
diff --git a/pkg/sharding/log_index_test.go b/pkg/sharding/log_index_test.go
index 039c4ef30..ba2fc5c1a 100644
--- a/pkg/sharding/log_index_test.go
+++ b/pkg/sharding/log_index_test.go
@@ -44,7 +44,7 @@ func TestVirtualLogIndex(t *testing.T) {
 TreeID: 100,
 TreeLength: 5,
 }},
- active: 300,
+ active: LogRange{TreeID: 300},
 },
 expectedIndex: 7,
 },
@@ -64,7 +64,7 @@ func TestVirtualLogIndex(t *testing.T) {
 TreeID: 300,
 TreeLength: 4,
 }},
- active: 400,
+ active: LogRange{TreeID: 400},
 },
 expectedIndex: 6,
 },
@@ -74,7 +74,7 @@ func TestVirtualLogIndex(t *testing.T) {
 leafIndex: 2,
 tid: 30,
 ranges: LogRanges{
- active: 30,
+ active: LogRange{TreeID: 30},
 },
 expectedIndex: 2,
 }, {
@@ -82,7 +82,7 @@ func TestVirtualLogIndex(t *testing.T) {
 leafIndex: 2,
 tid: 4,
 ranges: LogRanges{
- active: 30,
+ active: LogRange{TreeID: 30},
 },
 expectedIndex: -1,
 },
diff --git a/pkg/sharding/ranges.go b/pkg/sharding/ranges.go
index 8556d8027..fae1d0c98 100644
--- a/pkg/sharding/ranges.go
+++ b/pkg/sharding/ranges.go
@@ -17,7 +17,9 @@ package sharding
 import (
 "context"
- "encoding/base64"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/hex"
 "encoding/json"
 "errors"
 "fmt"
@@ -28,50 +30,80 @@ import (
 "github.com/google/trillian"
 "github.com/google/trillian/types"
 "github.com/sigstore/rekor/pkg/log"
+ "github.com/sigstore/rekor/pkg/signer"
+ "github.com/sigstore/sigstore/pkg/cryptoutils"
+ "github.com/sigstore/sigstore/pkg/signature"
+ "github.com/sigstore/sigstore/pkg/signature/options"
 "sigs.k8s.io/yaml"
 )
+// LogRanges holds the active and inactive shards
 type LogRanges struct {
+ // inactive shards are listed from oldest to newest
 inactive Ranges
- active int64
+ active LogRange
 }
 type Ranges []LogRange
+// LogRange represents a log or tree shard
 type LogRange struct {
- TreeID int64 `json:"treeID" yaml:"treeID"`
- TreeLength int64 `json:"treeLength" yaml:"treeLength"`
- EncodedPublicKey string `json:"encodedPublicKey" yaml:"encodedPublicKey"`
- decodedPublicKey string
+ TreeID int64 `json:"treeID" yaml:"treeID"`
+ TreeLength int64 `json:"treeLength" yaml:"treeLength"` // unused for active tree
+ SigningConfig signer.SigningConfig `json:"signingConfig" yaml:"signingConfig"` // if unset, assume same as active tree
+ Signer signature.Signer
+ PemPubKey string // PEM-encoded PKIX public key
+ LogID string // Hex-encoded SHA256 digest of PKIX-encoded public key
 }
-func NewLogRanges(ctx context.Context, logClient trillian.TrillianLogClient, path string, treeID uint) (LogRanges, error) {
- if path == "" {
- log.Logger.Info("No config file specified, skipping init of logRange map")
- return LogRanges{}, nil
+func (l LogRange) String() string {
+ return fmt.Sprintf("{ TreeID: %v, TreeLength: %v, SigningConfig: %v, PemPubKey: %v, LogID: %v }", l.TreeID, l.TreeLength, l.SigningConfig, l.PemPubKey, l.LogID)
+}
+
+// NewLogRanges initializes the active and any inactive log shards
+func NewLogRanges(ctx context.Context, logClient trillian.TrillianLogClient,
+ inactiveShardsPath string, activeTreeID int64, signingConfig signer.SigningConfig) (LogRanges, error) {
+ if activeTreeID == 0 {
+ return LogRanges{}, errors.New("non-zero active tree ID required; please set the active tree ID via the `--trillian_log_server.tlog_id` flag")
 }
- if treeID == 0 {
- return LogRanges{}, errors.New("non-zero tlog_id required when passing in shard config filepath; please set the active tree ID via the `--trillian_log_server.tlog_id` flag")
+
+ // Initialize active shard
+ activeLog, err := updateRange(ctx, logClient, LogRange{TreeID: activeTreeID, TreeLength: 0, SigningConfig: signingConfig}, true /*=active*/)
+ if err != nil {
+ return LogRanges{}, fmt.Errorf("creating range for active tree %d: %w", activeTreeID, err)
 }
- // otherwise, try to read contents of the sharding config
- ranges, err := logRangesFromPath(path)
+ log.Logger.Infof("Active log: %v", activeLog)
+
+ if inactiveShardsPath == "" {
+ log.Logger.Info("No config file specified, no inactive shards")
+ return LogRanges{active: activeLog}, nil
+ }
+
+ // Initialize inactive shards from inactive tree IDs
+ ranges, err := logRangesFromPath(inactiveShardsPath)
 if err != nil {
 return LogRanges{}, fmt.Errorf("log ranges from path: %w", err)
 }
 for i, r := range ranges {
- r, err := updateRange(ctx, logClient, r)
+ // If no signing config is provided, use the active tree signing key
+ if r.SigningConfig.IsUnset() {
+ r.SigningConfig = signingConfig
+ }
+ r, err := updateRange(ctx, logClient, r, false /*=active*/)
 if err != nil {
 return LogRanges{}, fmt.Errorf("updating range for tree id %d: %w", r.TreeID, err)
 }
 ranges[i] = r
 }
- log.Logger.Info("Ranges: %v", ranges)
+ log.Logger.Infof("Ranges: %v", ranges)
+
 return LogRanges{
 inactive: ranges,
- active: int64(treeID),
+ active: activeLog,
 }, nil
 }
+// logRangesFromPath unmarshals a shard config
 func logRangesFromPath(path string) (Ranges, error) {
 var ranges Ranges
 contents, err := os.ReadFile(path)
@@ -93,9 +125,9 @@ func logRangesFromPath(path string) (Ranges, error) {
 }
 // updateRange fills in any missing information about the range
-func updateRange(ctx context.Context, logClient trillian.TrillianLogClient, r LogRange) (LogRange, error) {
- // If a tree length wasn't passed in, get it ourselves
- if r.TreeLength == 0 {
+func updateRange(ctx context.Context, logClient trillian.TrillianLogClient, r LogRange, active bool) (LogRange, error) {
+ // If a tree length wasn't passed in and the shard is inactive, fetch the tree size
+ if r.TreeLength == 0 && !active {
 resp, err := logClient.GetLatestSignedLogRoot(ctx, &trillian.GetLatestSignedLogRootRequest{LogId: r.TreeID})
 if err != nil {
 return LogRange{}, fmt.Errorf("getting signed log root for tree %d: %w", r.TreeID, err)
@@ -106,14 +138,38 @@ func updateRange(ctx context.Context, logClient trillian.TrillianLogClient, r Lo
 }
 r.TreeLength = int64(root.TreeSize)
 }
-
- // If a public key was provided, decode it
- if r.EncodedPublicKey != "" {
- decoded, err := base64.StdEncoding.DecodeString(r.EncodedPublicKey)
- if err != nil {
- return LogRange{}, err
- }
- r.decodedPublicKey = string(decoded)
+
+ if r.SigningConfig.IsUnset() {
+ return LogRange{}, fmt.Errorf("signing config not set, unable to initialize shard signer")
+ }
+
+ // Initialize shard signer
+ s, err := signer.New(ctx, r.SigningConfig.SigningSchemeOrKeyPath, r.SigningConfig.FileSignerPassword,
+ r.SigningConfig.TinkKEKURI, r.SigningConfig.TinkKeysetPath)
+ if err != nil {
+ return LogRange{}, err
+ }
+ r.Signer = s
+
+ // Initialize public key
+ pubKey, err := s.PublicKey(options.WithContext(ctx))
+ if err != nil {
+ return LogRange{}, err
+ }
+ pemPubKey, err := cryptoutils.MarshalPublicKeyToPEM(pubKey)
+ if err != nil {
+ return LogRange{}, err
 }
+ r.PemPubKey = string(pemPubKey)
+
+ // Initialize log ID from public key
+ b, err := x509.MarshalPKIXPublicKey(pubKey)
+ if err != nil {
+ return LogRange{}, err
+ }
+ pubkeyHashBytes := sha256.Sum256(b)
+ r.LogID = hex.EncodeToString(pubkeyHashBytes[:])
+
 return r, nil
 }
@@ -127,11 +183,7 @@ func (l *LogRanges) ResolveVirtualIndex(index int) (int64, int64) {
 }
 // If index not found in inactive trees, return the active tree
- return l.active, int64(indexLeft)
-}
-
-func (l *LogRanges) ActiveTreeID() int64 {
- return l.active
+ return l.active.TreeID, int64(indexLeft)
 }
 func (l *LogRanges) NoInactive() bool {
@@ -140,7 +192,7 @@ func (l *LogRanges) NoInactive() bool {
 // AllShards returns all shards, starting with the active shard and then the inactive shards
 func (l *LogRanges) AllShards() []int64 {
- shards := []int64{l.ActiveTreeID()}
+ shards := []int64{l.GetActive().TreeID}
 for _, in := range l.GetInactive() {
 shards = append(shards, in.TreeID)
 }
@@ -157,23 +209,27 @@ func (l *LogRanges) TotalInactiveLength() int64 {
 return total
 }
-func (l *LogRanges) SetInactive(r []LogRange) {
- l.inactive = r
+// GetLogRangeByTreeID returns the active or inactive
+// shard with the given tree ID
+func (l *LogRanges) GetLogRangeByTreeID(treeID int64) (LogRange, error) {
+ if l.active.TreeID == treeID {
+ return l.active, nil
+ }
+ for _, i := range l.inactive {
+ if i.TreeID == treeID {
+ return i, nil
+ }
+ }
+ return LogRange{}, fmt.Errorf("no log range found for tree ID %d", treeID)
 }
+// GetInactive returns all inactive shards
 func (l *LogRanges) GetInactive() []LogRange {
 return l.inactive
 }
-func (l *LogRanges) AppendInactive(r LogRange) {
- l.inactive = append(l.inactive, r)
-}
-
-func (l *LogRanges) SetActive(i int64) {
- l.active = i
-}
-
-func (l *LogRanges) GetActive() int64 {
+// GetActive returns the active shard
+func (l *LogRanges) GetActive() LogRange {
 return l.active
 }
@@ -182,16 +238,16 @@ func (l *LogRanges) String() string {
 for _, r := range l.inactive {
 ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength))
 }
- ranges = append(ranges, fmt.Sprintf("active=%d", l.active))
+ ranges = append(ranges, fmt.Sprintf("active=%d", l.active.TreeID))
 return strings.Join(ranges, ",")
 }
 // PublicKey returns the associated public key for the given Tree ID
 // and returns the active public key by default
-func (l *LogRanges) PublicKey(activePublicKey, treeID string) (string, error) {
+func (l *LogRanges) PublicKey(treeID string) (string, error) {
 // if no tree ID is specified, assume the active tree
 if treeID == "" {
- return activePublicKey, nil
+ return l.active.PemPubKey, nil
 }
 tid, err := strconv.Atoi(treeID)
 if err != nil {
@@ -200,15 +256,11 @@ func (l *LogRanges) PublicKey(activePublicKey, treeID string) (string, error) {
 for _, i := range l.inactive {
 if int(i.TreeID) == tid {
- if i.decodedPublicKey != "" {
- return i.decodedPublicKey, nil
- }
- // assume the active public key if one wasn't provided
- return activePublicKey, nil
+ return i.PemPubKey, nil
 }
 }
- if tid == int(l.active) {
- return activePublicKey, nil
+ if tid == int(l.GetActive().TreeID) {
+ return l.active.PemPubKey, nil
 }
 return "", fmt.Errorf("%d is not a valid tree ID and doesn't have an associated public key", tid)
 }
diff --git a/pkg/sharding/ranges_test.go b/pkg/sharding/ranges_test.go
index ab020c800..48d7c5f41 100644
--- a/pkg/sharding/ranges_test.go
+++ b/pkg/sharding/ranges_test.go
@@ -17,15
+17,27 @@ package sharding import ( "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/hex" "encoding/json" "errors" + "fmt" "os" "path/filepath" "reflect" + "strings" "testing" "github.com/golang/mock/gomock" "github.com/google/trillian/testonly" + "github.com/sigstore/rekor/pkg/signer" + "github.com/sigstore/sigstore/pkg/cryptoutils" + "github.com/sigstore/sigstore/pkg/signature" "github.com/google/trillian" "google.golang.org/grpc" @@ -33,42 +45,83 @@ import ( ) func TestNewLogRanges(t *testing.T) { - contents := ` + keyPath, ecdsaSigner, pemPubKey, logID := initializeSigner(t) + sc := signer.SigningConfig{SigningSchemeOrKeyPath: keyPath} + + // inactive shard with different key + keyPathI, ecdsaSignerI, pemPubKeyI, logIDI := initializeSigner(t) + scI := signer.SigningConfig{SigningSchemeOrKeyPath: keyPathI} + + contents := fmt.Sprintf(` - treeID: 0001 treeLength: 3 - encodedPublicKey: c2hhcmRpbmcK - treeID: 0002 - treeLength: 4` + treeLength: 4 +- treeID: 0003 + treeLength: 5 + signingConfig: + signingSchemeOrKeyPath: '%s'`, keyPathI) + fmt.Println(contents) file := filepath.Join(t.TempDir(), "sharding-config") if err := os.WriteFile(file, []byte(contents), 0o644); err != nil { t.Fatal(err) } - treeID := uint(45) + treeID := int64(45) expected := LogRanges{ inactive: []LogRange{ + // two inactive shards without signing config + // inherit config from active shard { - TreeID: 1, - TreeLength: 3, - EncodedPublicKey: "c2hhcmRpbmcK", - decodedPublicKey: "sharding\n", + TreeID: 1, + TreeLength: 3, + SigningConfig: sc, + Signer: ecdsaSigner, + PemPubKey: pemPubKey, + LogID: logID, + }, { + TreeID: 2, + TreeLength: 4, + SigningConfig: sc, + Signer: ecdsaSigner, + PemPubKey: pemPubKey, + LogID: logID, }, { - TreeID: 2, - TreeLength: 4, + // inactive shard with custom signing config + TreeID: 3, + TreeLength: 5, + SigningConfig: scI, + Signer: ecdsaSignerI, + PemPubKey: pemPubKeyI, + LogID: logIDI, }, }, - active: int64(45), + active: LogRange{ + TreeID: 45, + TreeLength: 0, // unset + SigningConfig: sc, + Signer: ecdsaSigner, + PemPubKey: pemPubKey, + LogID: logID, + }, } ctx := context.Background() tc := trillian.NewTrillianLogClient(&grpc.ClientConn{}) - got, err := NewLogRanges(ctx, tc, file, treeID) + got, err := NewLogRanges(ctx, tc, file, treeID, sc) if err != nil { t.Fatal(err) } - if expected.ActiveTreeID() != got.ActiveTreeID() { - t.Fatalf("expected tree id %d got %d", expected.ActiveTreeID(), got.ActiveTreeID()) + if expected.GetActive().TreeID != got.GetActive().TreeID { + t.Fatalf("expected tree id %d got %d", expected.GetActive().TreeID, got.GetActive().TreeID) + } + for i, expected := range expected.GetInactive() { + got := got.GetInactive()[i] + logRangeEqual(t, expected, got) } - if !reflect.DeepEqual(expected.GetInactive(), got.GetInactive()) { - t.Fatalf("expected %v got %v", expected.GetInactive(), got.GetInactive()) + + // Failure: Tree ID = 0 + _, err = NewLogRanges(ctx, tc, file, 0, sc) + if err == nil || !strings.Contains(err.Error(), "non-zero active tree ID required") { + t.Fatal("expected error initializing log ranges with 0 tree ID") } } @@ -79,7 +132,7 @@ func TestLogRanges_ResolveVirtualIndex(t *testing.T) { {TreeID: 2, TreeLength: 1}, {TreeID: 3, TreeLength: 100}, }, - active: 4, + active: LogRange{TreeID: 4}, } for _, tt := range []struct { @@ -112,21 +165,66 @@ func TestLogRanges_ResolveVirtualIndex(t *testing.T) { } } -func TestPublicKey(t *testing.T) { +func 
TestLogRanges_GetLogRangeByTreeID(t *testing.T) { + lrs := LogRanges{ + inactive: []LogRange{ + {TreeID: 1, TreeLength: 17}, + {TreeID: 2, TreeLength: 1}, + {TreeID: 3, TreeLength: 100}, + }, + active: LogRange{TreeID: 4}, + } + + for _, tt := range []struct { + treeID int64 + wantLogRange LogRange + wantErr bool + }{ + // Active shard + { + treeID: 4, + wantLogRange: LogRange{TreeID: 4}, + wantErr: false, + }, + // One of the inactive shards + { + treeID: 2, + wantLogRange: LogRange{TreeID: 2, TreeLength: 1}, + wantErr: false, + }, + // Missing shard + { + treeID: 100, + wantLogRange: LogRange{}, + wantErr: true, + }, + } { + got, err := lrs.GetLogRangeByTreeID(tt.treeID) + if (err != nil) != tt.wantErr { + t.Errorf("GetLogRangeByTreeID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.wantLogRange, got) { + t.Fatalf("log range did not match: %v, %v", tt.wantLogRange, got) + } + } +} + +func TestLogRanges_PublicKey(t *testing.T) { ranges := LogRanges{ - active: 45, + active: LogRange{TreeID: 45, PemPubKey: "activekey"}, inactive: []LogRange{ { - TreeID: 10, - TreeLength: 10, - decodedPublicKey: "sharding", + TreeID: 10, + TreeLength: 10, + PemPubKey: "sharding10", }, { TreeID: 20, TreeLength: 20, + PemPubKey: "sharding20", }, }, } - activePubKey := "activekey" tests := []struct { description string treeID string @@ -139,11 +237,11 @@ func TestPublicKey(t *testing.T) { }, { description: "tree id with decoded public key", treeID: "10", - expectedPubKey: "sharding", + expectedPubKey: "sharding10", }, { description: "tree id without decoded public key", treeID: "20", - expectedPubKey: "activekey", + expectedPubKey: "sharding20", }, { description: "invalid tree id", treeID: "34", @@ -157,7 +255,7 @@ func TestPublicKey(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - got, err := ranges.PublicKey(activePubKey, test.treeID) + got, err := ranges.PublicKey(test.treeID) if err != nil && !test.shouldErr { t.Fatal(err) } @@ -174,7 +272,7 @@ func TestPublicKey(t *testing.T) { func TestLogRanges_String(t *testing.T) { type fields struct { inactive Ranges - active int64 + active LogRange } tests := []struct { name string @@ -185,7 +283,7 @@ func TestLogRanges_String(t *testing.T) { name: "empty", fields: fields{ inactive: Ranges{}, - active: 0, + active: LogRange{}, }, want: "active=0", }, @@ -198,7 +296,7 @@ func TestLogRanges_String(t *testing.T) { TreeLength: 2, }, }, - active: 3, + active: LogRange{TreeID: 3}, }, want: "1=2,active=3", }, @@ -215,7 +313,7 @@ func TestLogRanges_String(t *testing.T) { TreeLength: 3, }, }, - active: 4, + active: LogRange{TreeID: 4}, }, want: "1=2,2=3,active=4", }, @@ -236,7 +334,7 @@ func TestLogRanges_String(t *testing.T) { func TestLogRanges_TotalInactiveLength(t *testing.T) { type fields struct { inactive Ranges - active int64 + active LogRange } tests := []struct { name string @@ -247,7 +345,7 @@ func TestLogRanges_TotalInactiveLength(t *testing.T) { name: "empty", fields: fields{ inactive: Ranges{}, - active: 0, + active: LogRange{}, }, want: 0, }, @@ -260,7 +358,7 @@ func TestLogRanges_TotalInactiveLength(t *testing.T) { TreeLength: 2, }, }, - active: 3, + active: LogRange{TreeID: 3}, }, want: 2, }, @@ -281,7 +379,7 @@ func TestLogRanges_TotalInactiveLength(t *testing.T) { func TestLogRanges_AllShards(t *testing.T) { type fields struct { inactive Ranges - active int64 + active LogRange } tests := []struct { name string @@ -292,7 +390,7 @@ func TestLogRanges_AllShards(t *testing.T) { name: 
"empty", fields: fields{ inactive: Ranges{}, - active: 0, + active: LogRange{}, }, want: []int64{0}, }, @@ -305,7 +403,7 @@ func TestLogRanges_AllShards(t *testing.T) { TreeLength: 2, }, }, - active: 3, + active: LogRange{TreeID: 3}, }, want: []int64{3, 1}, }, @@ -322,7 +420,7 @@ func TestLogRanges_AllShards(t *testing.T) { TreeLength: 3, }, }, - active: 4, + active: LogRange{TreeID: 4}, }, want: []int64{4, 1, 2}, }, @@ -340,6 +438,35 @@ func TestLogRanges_AllShards(t *testing.T) { } } +func TestLogRanges_ActiveAndInactive(t *testing.T) { + active := LogRange{ + TreeID: 1, + } + inactive := Ranges{ + { + TreeID: 2, + TreeLength: 123, + }, + { + TreeID: 3, + TreeLength: 456, + }, + } + lr := LogRanges{ + active: active, + inactive: inactive, + } + if lr.NoInactive() { + t.Fatalf("expected inactive shards, got no shards") + } + if !reflect.DeepEqual(active, lr.active) { + t.Fatalf("expected active shards to be equal") + } + if !reflect.DeepEqual(inactive, lr.inactive) { + t.Fatalf("expected inactive shards to be equal") + } +} + func TestLogRangesFromPath(t *testing.T) { type args struct { path string @@ -501,7 +628,7 @@ func TestUpdateRange(t *testing.T) { s.Log.EXPECT().GetLatestSignedLogRoot( gomock.Any(), gomock.Any()).Return(tt.rootResponse, tt.signedLogError).AnyTimes() - got, err := updateRange(tt.args.ctx, s.LogClient, tt.args.r) + got, err := updateRange(tt.args.ctx, s.LogClient, tt.args.r, false) if (err != nil) != tt.wantErr { t.Errorf("updateRange() error = %v, wantErr %v", err, tt.wantErr) @@ -515,10 +642,13 @@ func TestUpdateRange(t *testing.T) { } func TestNewLogRangesWithMock(t *testing.T) { + keyPath, ecdsaSigner, pemPubKey, logID := initializeSigner(t) + sc := signer.SigningConfig{SigningSchemeOrKeyPath: keyPath} + type args struct { ctx context.Context path string - treeID uint + treeID int64 } tests := []struct { name string @@ -533,7 +663,16 @@ func TestNewLogRangesWithMock(t *testing.T) { path: "", treeID: 1, }, - want: LogRanges{}, + want: LogRanges{ + active: LogRange{ + TreeID: 1, + TreeLength: 0, + SigningConfig: sc, + Signer: ecdsaSigner, + PemPubKey: pemPubKey, + LogID: logID, + }, + }, wantErr: false, }, { @@ -558,14 +697,95 @@ func TestNewLogRangesWithMock(t *testing.T) { t.Fatalf("Failed to create mock server: %v", err) } defer fakeServer() - got, err := NewLogRanges(tt.args.ctx, s.LogClient, tt.args.path, tt.args.treeID) + got, err := NewLogRanges(tt.args.ctx, s.LogClient, tt.args.path, tt.args.treeID, sc) if (err != nil) != tt.wantErr { t.Errorf("NewLogRanges() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewLogRanges() got = %v, want %v", got, tt.want) + if !tt.wantErr { + logRangesEqual(t, tt.want, got) } }) } } + +// initializeSigner returns a path to an ECDSA private key, an ECDSA signer, +// PEM-encoded public key, and log ID +func initializeSigner(t *testing.T) (string, signature.Signer, string, string) { + td := t.TempDir() + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + pemPrivKey, err := cryptoutils.MarshalPrivateKeyToPEM(privKey) + if err != nil { + t.Fatal(err) + } + signer, err := signature.LoadECDSASigner(privKey, crypto.SHA256) + if err != nil { + t.Fatal(err) + } + // Encode public key + pubKey, err := signer.PublicKey() + if err != nil { + t.Fatal(err) + } + pemPubKey, err := cryptoutils.MarshalPublicKeyToPEM(pubKey) + if err != nil { + t.Fatal(err) + } + // Calculate log ID + b, err := x509.MarshalPKIXPublicKey(pubKey) + if err 
!= nil {
+ t.Fatal(err)
+ }
+ pubkeyHashBytes := sha256.Sum256(b)
+ logID := hex.EncodeToString(pubkeyHashBytes[:])
+
+ keyFile := filepath.Join(td, fmt.Sprintf("%s-ecdsa-key.pem", logID))
+ if err := os.WriteFile(keyFile, pemPrivKey, 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ return keyFile, signer, string(pemPubKey), logID
+}
+
+func logRangesEqual(t *testing.T, expected, got LogRanges) {
+ logRangeEqual(t, expected.active, got.active)
+ if len(expected.inactive) != len(got.inactive) {
+ t.Fatalf("inactive log ranges are not equal")
+ }
+ for i, lr := range expected.inactive {
+ g := got.inactive[i]
+ logRangeEqual(t, lr, g)
+ }
+}
+
+func logRangeEqual(t *testing.T, expected, got LogRange) {
+ if expected.TreeID != got.TreeID {
+ t.Fatalf("expected tree ID %v, got %v", expected.TreeID, got.TreeID)
+ }
+ if expected.TreeLength != got.TreeLength {
+ t.Fatalf("expected tree length %v, got %v", expected.TreeLength, got.TreeLength)
+ }
+ if !reflect.DeepEqual(expected.SigningConfig, got.SigningConfig) {
+ t.Fatalf("expected signing config %v, got %v", expected.SigningConfig, got.SigningConfig)
+ }
+ expectedPubKey, err := expected.Signer.PublicKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotPubKey, err := got.Signer.PublicKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := cryptoutils.EqualKeys(expectedPubKey, gotPubKey); err != nil {
+ t.Fatal(err)
+ }
+ if expected.PemPubKey != got.PemPubKey {
+ t.Fatalf("expected public key %v, got %v", expected.PemPubKey, got.PemPubKey)
+ }
+ if expected.LogID != got.LogID {
+ t.Fatalf("expected log ID %v, got %v", expected.LogID, got.LogID)
+ }
+}
diff --git a/pkg/signer/signer.go b/pkg/signer/signer.go
index 93d868d4e..d230f1f80 100644
--- a/pkg/signer/signer.go
+++ b/pkg/signer/signer.go
@@ -32,6 +32,19 @@ import (
 _ "github.com/sigstore/sigstore/pkg/signature/kms/hashivault"
 )
+// SigningConfig holds the configuration used to initialize the signer for a specific shard
+type SigningConfig struct {
+ SigningSchemeOrKeyPath string `json:"signingSchemeOrKeyPath" yaml:"signingSchemeOrKeyPath"`
+ FileSignerPassword string `json:"fileSignerPassword" yaml:"fileSignerPassword"`
+ TinkKEKURI string `json:"tinkKEKURI" yaml:"tinkKEKURI"`
+ TinkKeysetPath string `json:"tinkKeysetPath" yaml:"tinkKeysetPath"`
+}
+
+func (sc SigningConfig) IsUnset() bool {
+ return sc.SigningSchemeOrKeyPath == "" && sc.FileSignerPassword == "" &&
+ sc.TinkKEKURI == "" && sc.TinkKeysetPath == ""
+}
+
 func New(ctx context.Context, signer, pass, tinkKEKURI, tinkKeysetPath string) (signature.Signer, error) {
 switch {
 case slices.ContainsFunc(kms.SupportedProviders(),
diff --git a/pkg/signer/signer_test.go b/pkg/signer/signer_test.go
new file mode 100644
index 000000000..8e2b599ce
--- /dev/null
+++ b/pkg/signer/signer_test.go
@@ -0,0 +1,24 @@
+// Copyright 2025 The Sigstore Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package signer + +import "testing" + +func TestSigningConfig(t *testing.T) { + sc := SigningConfig{} + if !sc.IsUnset() { + t.Fatalf("expected empty signing config to be unset") + } +} diff --git a/pkg/signer/tink.go b/pkg/signer/tink.go index b135db591..28b789bf0 100644 --- a/pkg/signer/tink.go +++ b/pkg/signer/tink.go @@ -17,6 +17,7 @@ package signer import ( "context" "errors" + "fmt" "os" "path/filepath" "strings" @@ -36,6 +37,9 @@ const TinkScheme = "tink" // NewTinkSignerWithHandle returns a signature.SignerVerifier that wraps crypto.Signer and a hash function. // Provide a path to the encrypted keyset and cloud KMS key URI for decryption func NewTinkSigner(ctx context.Context, kekURI, keysetPath string) (signature.Signer, error) { + if kekURI == "" || keysetPath == "" { + return nil, fmt.Errorf("key encryption key URI or keyset path unset") + } kek, err := getKeyEncryptionKey(ctx, kekURI) if err != nil { return nil, err diff --git a/tests/sharding-e2e-test.sh b/tests/sharding-e2e-test.sh index f74117219..71a3d2b46 100755 --- a/tests/sharding-e2e-test.sh +++ b/tests/sharding-e2e-test.sh @@ -133,9 +133,6 @@ echo "the new shard ID is $SHARD_TREE_ID" # Once more $REKOR_CLI loginfo --rekor_server http://localhost:3000 -# Get the public key for the active tree for later -ENCODED_PUBLIC_KEY=$(curl http://localhost:3000/api/v1/log/publicKey | base64 -w 0) - # Spin down the rekor server echo "stopping the rekor server..." REKOR_CONTAINER_ID=$(docker ps --filter name=rekor-server --format {{.ID}}) @@ -143,10 +140,12 @@ docker stop $REKOR_CONTAINER_ID # Now we want to spin up the Rekor server again, but this time point # to the new tree +# New shard will have a different signing key. SHARDING_CONFIG=sharding-config.yaml cat << EOF > $SHARDING_CONFIG - treeID: $INITIAL_TREE_ID - encodedPublicKey: $ENCODED_PUBLIC_KEY + signingConfig: + signingSchemeOrKeyPath: memory EOF cat $SHARDING_CONFIG @@ -226,18 +225,16 @@ $REKOR_CLI logproof --last-size 2 --tree-id $INITIAL_TREE_ID --rekor_server http # And the logproof for the now active shard $REKOR_CLI logproof --last-size 1 --rekor_server http://localhost:3000 +# Make sure the shard keys are different echo "Getting public key for inactive shard..." -GOT_PUB_KEY=$(curl "http://localhost:3000/api/v1/log/publicKey?treeID=$INITIAL_TREE_ID" | base64 -w 0) -echo "Got encoded public key $GOT_PUB_KEY, making sure this matches the public key we got earlier..." -stringsMatch $ENCODED_PUBLIC_KEY $GOT_PUB_KEY - +INACTIVE_PUB_KEY=$(curl "http://localhost:3000/api/v1/log/publicKey?treeID=$INITIAL_TREE_ID" | base64 -w 0) echo "Getting the public key for the active tree..." NEW_PUB_KEY=$(curl "http://localhost:3000/api/v1/log/publicKey" | base64 -w 0) echo "Making sure the public key for the active shard is different from the inactive shard..." -if [[ "$ENCODED_PUBLIC_KEY" == "$NEW_PUB_KEY" ]]; then +if [[ "$INACTIVE_PUB_KEY" == "$NEW_PUB_KEY" ]]; then echo echo "Active tree public key should be different from inactive shard public key but isn't..." - echo "Inactive Shard Public Key: $ENCODED_PUBLIC_KEY" + echo "Inactive Shard Public Key: $INACTIVE_PUB_KEY" echo "Active Shard Public Key: $NEW_PUB_KEY" exit 1 fi
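
Example of the resulting config shape: each inactive-shard entry may now
carry its own signingConfig, and entries that omit it inherit the active
shard's signing configuration (the SigningConfig.IsUnset() fallback in
NewLogRanges). A minimal sketch with placeholder tree IDs; the "memory"
scheme is the ephemeral signer also used in tests/sharding-e2e-test.sh:

    # one shard inheriting the active key, one with its own signing key
    cat << EOF > sharding-config.yaml
    - treeID: 0001
      treeLength: 3
    - treeID: 0002
      treeLength: 4
      signingConfig:
        signingSchemeOrKeyPath: memory
    EOF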
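Per-shard verification keys are served by the existing publicKey
endpoint: omitting treeID returns the active shard's PEM-encoded key,
while passing a tree ID selects that shard's key, as exercised at the
end of tests/sharding-e2e-test.sh (host and port are illustrative):

    # fetch the active shard's public key (no treeID parameter)
    curl "http://localhost:3000/api/v1/log/publicKey"
    # fetch the key for a specific shard, e.g. the now-inactive initial tree
    curl "http://localhost:3000/api/v1/log/publicKey?treeID=$INITIAL_TREE_ID"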