diff --git a/ca/ca.go b/ca/ca.go index d2d48e55826..87a6fc52c70 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -15,7 +15,6 @@ import ( "fmt" "math/big" mrand "math/rand/v2" - "strings" "time" ct "github.com/google/certificate-transparency-go" @@ -51,6 +50,25 @@ const ( certType = certificateType("certificate") ) +// issuanceEvent is logged before and after issuance of precertificates and certificates. +// The `omitempty` fields are not always present. +// CSR, Precertificate, and Certificate are hex-encoded DER bytes to make it easier to +// ad-hoc search for sequences or OIDs in logs. Other data, like public key within CSR, +// is logged as base64 because it doesn't have interesting DER structure. +type issuanceEvent struct { + CSR string `json:",omitempty"` + IssuanceRequest *issuance.IssuanceRequest + Issuer string + OrderID int64 + Profile string + ProfileHash string + Requester int64 + Result struct { + Precertificate string `json:",omitempty"` + Certificate string `json:",omitempty"` + } +} + // Two maps of keys to Issuers. Lookup by PublicKeyAlgorithm is useful for // determining the set of issuers which can sign a given (pre)cert, based on its // PublicKeyAlgorithm. Lookup by NameID is useful for looking up a specific @@ -428,17 +446,22 @@ func (ca *certificateAuthorityImpl) IssueCertificateForPrecertificate(ctx contex return nil, err } - names := strings.Join(issuanceReq.DNSNames, ", ") - ca.log.AuditInfof("Signing cert: issuer=[%s] serial=[%s] regID=[%d] names=[%s] certProfileName=[%s] certProfileHash=[%x] precert=[%s]", - issuer.Name(), serialHex, req.RegistrationID, names, certProfile.name, certProfile.hash, hex.EncodeToString(precert.Raw)) - lintCertBytes, issuanceToken, err := issuer.Prepare(certProfile.profile, issuanceReq) if err != nil { - ca.log.AuditErrf("Preparing cert failed: issuer=[%s] serial=[%s] regID=[%d] names=[%s] certProfileName=[%s] certProfileHash=[%x] err=[%v]", - issuer.Name(), serialHex, req.RegistrationID, names, certProfile.name, certProfile.hash, err) + ca.log.AuditErrf("Preparing cert failed: serial=[%s] err=[%v]", serialHex, err) return nil, berrors.InternalServerError("failed to prepare certificate signing: %s", err) } + logEvent := issuanceEvent{ + IssuanceRequest: issuanceReq, + Issuer: issuer.Name(), + OrderID: req.OrderID, + Profile: certProfile.name, + ProfileHash: hex.EncodeToString(certProfile.hash[:]), + Requester: req.RegistrationID, + } + ca.log.AuditObject("Signing cert", logEvent) + _, span := ca.tracer.Start(ctx, "signing cert", trace.WithAttributes( attribute.String("serial", serialHex), attribute.String("issuer", issuer.Name()), @@ -448,8 +471,7 @@ func (ca *certificateAuthorityImpl) IssueCertificateForPrecertificate(ctx contex certDER, err := issuer.Issue(issuanceToken) if err != nil { ca.metrics.noteSignError(err) - ca.log.AuditErrf("Signing cert failed: issuer=[%s] serial=[%s] regID=[%d] names=[%s] certProfileName=[%s] certProfileHash=[%x] err=[%v]", - issuer.Name(), serialHex, req.RegistrationID, names, certProfile.name, certProfile.hash, err) + ca.log.AuditErrf("Signing cert failed: serial=[%s] err=[%v]", serialHex, err) span.SetStatus(codes.Error, err.Error()) span.End() return nil, berrors.InternalServerError("failed to sign certificate: %s", err) @@ -462,8 +484,8 @@ func (ca *certificateAuthorityImpl) IssueCertificateForPrecertificate(ctx contex } ca.metrics.signatureCount.With(prometheus.Labels{"purpose": string(certType), "issuer": issuer.Name()}).Inc() - ca.log.AuditInfof("Signing cert success: issuer=[%s] serial=[%s] regID=[%d] names=[%s] certificate=[%s] certProfileName=[%s] certProfileHash=[%x]", - issuer.Name(), serialHex, req.RegistrationID, names, hex.EncodeToString(certDER), certProfile.name, certProfile.hash) + logEvent.Result.Certificate = hex.EncodeToString(certDER) + ca.log.AuditObject("Signing cert success", logEvent) _, err = ca.sa.AddCertificate(ctx, &sapb.AddCertificateRequest{ Der: certDER, @@ -471,8 +493,7 @@ func (ca *certificateAuthorityImpl) IssueCertificateForPrecertificate(ctx contex Issued: timestamppb.New(ca.clk.Now()), }) if err != nil { - ca.log.AuditErrf("Failed RPC to store at SA: issuer=[%s] serial=[%s] cert=[%s] regID=[%d] orderID=[%d] certProfileName=[%s] certProfileHash=[%x] err=[%v]", - issuer.Name(), serialHex, hex.EncodeToString(certDER), req.RegistrationID, req.OrderID, certProfile.name, certProfile.hash, err) + ca.log.AuditErrf("Failed RPC to store at SA: serial=[%s] err=[%v]", serialHex, hex.EncodeToString(certDER)) return nil, err } @@ -568,7 +589,7 @@ func (ca *certificateAuthorityImpl) issuePrecertificateInner(ctx context.Context names := csrlib.NamesFromCSR(csr) req := &issuance.IssuanceRequest{ - PublicKey: csr.PublicKey, + PublicKey: issuance.MarshalablePublicKey{PublicKey: csr.PublicKey}, SubjectKeyId: subjectKeyId, Serial: serialBigInt.Bytes(), DNSNames: names.SANs, @@ -579,19 +600,20 @@ func (ca *certificateAuthorityImpl) issuePrecertificateInner(ctx context.Context NotAfter: notAfter, } - ca.log.AuditInfof("Signing precert: serial=[%s] regID=[%d] names=[%s] csr=[%s]", - serialHex, issueReq.RegistrationID, strings.Join(req.DNSNames, ", "), hex.EncodeToString(csr.Raw)) - lintCertBytes, issuanceToken, err := issuer.Prepare(certProfile.profile, req) if err != nil { - ca.log.AuditErrf("Preparing precert failed: issuer=[%s] serial=[%s] regID=[%d] names=[%s] certProfileName=[%s] certProfileHash=[%x] err=[%v]", - issuer.Name(), serialHex, issueReq.RegistrationID, strings.Join(req.DNSNames, ", "), certProfile.name, certProfile.hash, err) + ca.log.AuditErrf("Preparing precert failed: serial=[%s] err=[%v]", serialHex, err) if errors.Is(err, linter.ErrLinting) { ca.metrics.lintErrorCount.Inc() } return nil, nil, berrors.InternalServerError("failed to prepare precertificate signing: %s", err) } + // Note: we write the linting certificate bytes to this table, rather than the precertificate + // (which we audit log but do not put in the database). This is to ensure that even if there is + // an error immediately after signing the precertificate, we have a record in the DB of what we + // intended to sign, and can do revocations based on that. See #6807. + // The name of the SA method ("AddPrecertificate") is a historical artifact. _, err = ca.sa.AddPrecertificate(context.Background(), &sapb.AddCertificateRequest{ Der: lintCertBytes, RegID: issueReq.RegistrationID, @@ -603,6 +625,17 @@ func (ca *certificateAuthorityImpl) issuePrecertificateInner(ctx context.Context return nil, nil, err } + logEvent := issuanceEvent{ + CSR: hex.EncodeToString(csr.Raw), + IssuanceRequest: req, + Issuer: issuer.Name(), + Profile: certProfile.name, + ProfileHash: hex.EncodeToString(certProfile.hash[:]), + Requester: issueReq.RegistrationID, + OrderID: issueReq.OrderID, + } + ca.log.AuditObject("Signing precert", logEvent) + _, span := ca.tracer.Start(ctx, "signing precert", trace.WithAttributes( attribute.String("serial", serialHex), attribute.String("issuer", issuer.Name()), @@ -612,8 +645,7 @@ func (ca *certificateAuthorityImpl) issuePrecertificateInner(ctx context.Context certDER, err := issuer.Issue(issuanceToken) if err != nil { ca.metrics.noteSignError(err) - ca.log.AuditErrf("Signing precert failed: issuer=[%s] serial=[%s] regID=[%d] names=[%s] certProfileName=[%s] certProfileHash=[%x] err=[%v]", - issuer.Name(), serialHex, issueReq.RegistrationID, strings.Join(req.DNSNames, ", "), certProfile.name, certProfile.hash, err) + ca.log.AuditErrf("Signing precert failed: serial=[%s] err=[%v]", serialHex, err) span.SetStatus(codes.Error, err.Error()) span.End() return nil, nil, berrors.InternalServerError("failed to sign precertificate: %s", err) @@ -626,8 +658,11 @@ func (ca *certificateAuthorityImpl) issuePrecertificateInner(ctx context.Context } ca.metrics.signatureCount.With(prometheus.Labels{"purpose": string(precertType), "issuer": issuer.Name()}).Inc() - ca.log.AuditInfof("Signing precert success: issuer=[%s] serial=[%s] regID=[%d] names=[%s] precert=[%s] certProfileName=[%s] certProfileHash=[%x]", - issuer.Name(), serialHex, issueReq.RegistrationID, strings.Join(req.DNSNames, ", "), hex.EncodeToString(certDER), certProfile.name, certProfile.hash) + + logEvent.Result.Precertificate = hex.EncodeToString(certDER) + // The CSR is big and not that informative, so don't log it a second time. + logEvent.CSR = "" + ca.log.AuditObject("Signing precert success", logEvent) return certDER, &certProfileWithID{certProfile.name, certProfile.hash, nil}, nil } diff --git a/ca/ca_test.go b/ca/ca_test.go index 14356066ea6..c96de987669 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -332,7 +332,6 @@ func TestIssuePrecertificate(t *testing.T) { var certDER []byte response, err := ca.IssuePrecertificate(ctx, issueReq) - test.AssertNotError(t, err, "Failed to issue precertificate") certDER = response.DER diff --git a/cmd/admin/key.go b/cmd/admin/key.go index 66da63ebeef..250eb0225d3 100644 --- a/cmd/admin/key.go +++ b/cmd/admin/key.go @@ -3,7 +3,9 @@ package main import ( "bufio" "context" + "crypto/x509" "encoding/hex" + "encoding/pem" "errors" "flag" "fmt" @@ -26,9 +28,14 @@ import ( type subcommandBlockKey struct { parallelism uint comment string - privKey string - spkiFile string - certFile string + + privKey string + spkiFile string + certFile string + csrFile string + csrFileExpectedCN string + + checkSignature bool } var _ subcommand = (*subcommandBlockKey)(nil) @@ -46,6 +53,10 @@ func (s *subcommandBlockKey) Flags(flag *flag.FlagSet) { flag.StringVar(&s.privKey, "private-key", "", "Block issuance for the pubkey corresponding to this private key") flag.StringVar(&s.spkiFile, "spki-file", "", "Block issuance for all keys listed in this file as SHA256 hashes of SPKI, hex encoded, one per line") flag.StringVar(&s.certFile, "cert-file", "", "Block issuance for the public key of the single PEM-formatted certificate in this file") + flag.StringVar(&s.csrFile, "csr-file", "", "Block issuance for the public key of the single PEM-formatted CSR in this file") + flag.StringVar(&s.csrFileExpectedCN, "csr-file-expected-cn", "The key that signed this CSR has been publicly disclosed. It should not be used for any purpose.", "The Subject CN of a CSR will be verified to match this before blocking") + + flag.BoolVar(&s.checkSignature, "check-signature", true, "Check self-signature of CSR before revoking") } func (s *subcommandBlockKey) Run(ctx context.Context, a *admin) error { @@ -56,6 +67,7 @@ func (s *subcommandBlockKey) Run(ctx context.Context, a *admin) error { "-private-key": s.privKey != "", "-spki-file": s.spkiFile != "", "-cert-file": s.certFile != "", + "-csr-file": s.csrFile != "", } maps.DeleteFunc(setInputs, func(_ string, v bool) bool { return !v }) if len(setInputs) == 0 { @@ -75,6 +87,8 @@ func (s *subcommandBlockKey) Run(ctx context.Context, a *admin) error { spkiHashes, err = a.spkiHashesFromFile(s.spkiFile) case "-cert-file": spkiHashes, err = a.spkiHashesFromCertPEM(s.certFile) + case "-csr-file": + spkiHashes, err = a.spkiHashFromCSRPEM(s.csrFile, s.checkSignature, s.csrFileExpectedCN) default: return errors.New("no recognized input method flag set (this shouldn't happen)") } @@ -146,6 +160,43 @@ func (a *admin) spkiHashesFromCertPEM(filename string) ([][]byte, error) { return [][]byte{spkiHash[:]}, nil } +func (a *admin) spkiHashFromCSRPEM(filename string, checkSignature bool, expectedCN string) ([][]byte, error) { + csrFile, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("reading CSR file %q: %w", filename, err) + } + + data, _ := pem.Decode(csrFile) + if data == nil { + return nil, fmt.Errorf("no PEM data found in %q", filename) + } + + a.log.AuditInfof("Parsing key to block from CSR PEM: %x", data) + + csr, err := x509.ParseCertificateRequest(data.Bytes) + if err != nil { + return nil, fmt.Errorf("parsing CSR %q: %w", filename, err) + } + + if checkSignature { + err = csr.CheckSignature() + if err != nil { + return nil, fmt.Errorf("checking CSR signature: %w", err) + } + } + + if csr.Subject.CommonName != expectedCN { + return nil, fmt.Errorf("Got CSR CommonName %q, expected %q", csr.Subject.CommonName, expectedCN) + } + + spkiHash, err := core.KeyDigest(csr.PublicKey) + if err != nil { + return nil, fmt.Errorf("computing SPKI hash: %w", err) + } + + return [][]byte{spkiHash[:]}, nil +} + func (a *admin) blockSPKIHashes(ctx context.Context, spkiHashes [][]byte, comment string, parallelism uint) error { u, err := user.Current() if err != nil { diff --git a/cmd/admin/key_test.go b/cmd/admin/key_test.go index 0bb19223609..ef4428c0a3c 100644 --- a/cmd/admin/key_test.go +++ b/cmd/admin/key_test.go @@ -68,6 +68,53 @@ func TestSPKIHashesFromFile(t *testing.T) { } } +// The key is the p256 test key from RFC9500 +const goodCSR = ` +-----BEGIN CERTIFICATE REQUEST----- +MIG6MGICAQAwADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABEIlSPiPt4L/teyj +dERSxyoeVY+9b3O+XkjpMjLMRcWxbEzRDEy41bihcTnpSILImSVymTQl9BQZq36Q +pCpJQnKgADAKBggqhkjOPQQDAgNIADBFAiBadw3gvL9IjUfASUTa7MvmkbC4ZCvl +21m1KMwkIx/+CQIhAKvuyfCcdZ0cWJYOXCOb1OavolWHIUzgEpNGUWul6O0s +-----END CERTIFICATE REQUEST----- +` + +// TestCSR checks that we get the correct SPKI from a CSR, even if its signature is invalid +func TestCSR(t *testing.T) { + expectedSPKIHash := "b2b04340cfaee616ec9c2c62d261b208e54bb197498df52e8cadede23ac0ba5e" + + goodCSRFile := path.Join(t.TempDir(), "good.csr") + err := os.WriteFile(goodCSRFile, []byte(goodCSR), 0600) + test.AssertNotError(t, err, "writing good csr") + + a := admin{log: blog.NewMock()} + + goodHash, err := a.spkiHashFromCSRPEM(goodCSRFile, true, "") + test.AssertNotError(t, err, "expected to read CSR") + + if len(goodHash) != 1 { + t.Fatalf("expected to read 1 SPKI from CSR, read %d", len(goodHash)) + } + test.AssertEquals(t, hex.EncodeToString(goodHash[0]), expectedSPKIHash) + + // Flip a bit, in the signature, to make a bad CSR: + badCSR := strings.Replace(goodCSR, "Wul6", "Wul7", 1) + + csrFile := path.Join(t.TempDir(), "bad.csr") + err = os.WriteFile(csrFile, []byte(badCSR), 0600) + test.AssertNotError(t, err, "writing bad csr") + + _, err = a.spkiHashFromCSRPEM(csrFile, true, "") + test.AssertError(t, err, "expected invalid signature") + + badHash, err := a.spkiHashFromCSRPEM(csrFile, false, "") + test.AssertNotError(t, err, "expected to read CSR with bad signature") + + if len(badHash) != 1 { + t.Fatalf("expected to read 1 SPKI from CSR, read %d", len(badHash)) + } + test.AssertEquals(t, hex.EncodeToString(badHash[0]), expectedSPKIHash) +} + // mockSARecordingBlocks is a mock which only implements the AddBlockedKey gRPC // method. type mockSARecordingBlocks struct { diff --git a/core/objects.go b/core/objects.go index a4a5240df37..6732d623104 100644 --- a/core/objects.go +++ b/core/objects.go @@ -297,13 +297,13 @@ func (ch Challenge) StringID() string { type Authorization struct { // An identifier for this authorization, unique across // authorizations and certificates within this instance. - ID string `json:"id,omitempty" db:"id"` + ID string `json:"-" db:"id"` // The identifier for which authorization is being given Identifier identifier.ACMEIdentifier `json:"identifier,omitempty" db:"identifier"` // The registration ID associated with the authorization - RegistrationID int64 `json:"regId,omitempty" db:"registrationID"` + RegistrationID int64 `json:"-" db:"registrationID"` // The status of the validation of this authorization Status AcmeStatus `json:"status,omitempty" db:"status"` diff --git a/features/features.go b/features/features.go index ce677a99ed9..efde340c0d8 100644 --- a/features/features.go +++ b/features/features.go @@ -103,12 +103,26 @@ type Config struct { // This flag should only be used in conjunction with UseKvLimitsForNewOrder. DisableLegacyLimitWrites bool + // PropagateCancels controls whether the WFE and ocsp-responder allows + // cancellation of an inbound request to cancel downstream gRPC and other + // queries. In practice, cancellation of an inbound request is achieved by + // Nginx closing the connection on which the request was happening. This may + // help shed load in overcapacity situations. However, note that in-progress + // database queries (for instance, in the SA) are not cancelled. Database + // queries waiting for an available connection may be cancelled. + PropagateCancels bool + // InsertAuthzsIndividually causes the SA's NewOrderAndAuthzs method to // create each new authz one at a time, rather than using MultiInserter. // Although this is expected to be a performance penalty, it is necessary to // get the AUTO_INCREMENT ID of each new authz without relying on MariaDB's // unique "INSERT ... RETURNING" functionality. InsertAuthzsIndividually bool + + // IncrementRateLimits uses Redis' IncrBy, instead of Set, for rate limit + // accounting. This catches and denies spikes of requests much more + // reliably. + IncrementRateLimits bool } var fMu = new(sync.RWMutex) diff --git a/issuance/cert.go b/issuance/cert.go index b59ec8d8272..884ece7c4df 100644 --- a/issuance/cert.go +++ b/issuance/cert.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/asn1" + "encoding/json" "errors" "fmt" "math/big" @@ -104,7 +105,18 @@ func NewProfile(profileConfig *ProfileConfig) (*Profile, error) { return nil, fmt.Errorf("validity period %q is too large", profileConfig.MaxValidityPeriod.Duration) } - lints, err := linter.NewRegistry(profileConfig.IgnoredLints) + // TODO(#7756): These lint names don't yet exist in our current zlint v3.6.0 but exist in v3.6.2. + // In order to upgrade without throwing errors, we need to add these to our ignored lints. + // However, v3.6.0 will error if it sees ignored lints it doesn't recognize. Solution: filter + // out these specific lints. As part of the PR that updates to v3.6.2, we will remove this code. + var ignoredLints []string + for _, lintName := range profileConfig.IgnoredLints { + if lintName != "e_cab_dv_subject_invalid_values" && lintName != "w_ext_subject_key_identifier_not_recommended_subscriber" { + ignoredLints = append(ignoredLints, lintName) + } + } + + lints, err := linter.NewRegistry(ignoredLints) cmd.FailOnError(err, "Failed to create zlint registry") if profileConfig.LintConfig != "" { lintconfig, err := lint.NewConfigFromFile(profileConfig.LintConfig) @@ -142,7 +154,7 @@ func (p *Profile) GenerateValidity(now time.Time) (time.Time, time.Time) { // requestValid verifies the passed IssuanceRequest against the profile. If the // request doesn't match the signing profile an error is returned. func (i *Issuer) requestValid(clk clock.Clock, prof *Profile, req *IssuanceRequest) error { - switch req.PublicKey.(type) { + switch req.PublicKey.PublicKey.(type) { case *rsa.PublicKey, *ecdsa.PublicKey: default: return errors.New("unsupported public key type") @@ -250,12 +262,36 @@ var mustStapleExt = pkix.Extension{ Value: []byte{0x30, 0x03, 0x02, 0x01, 0x05}, } +// MarshalablePublicKey is a wrapper for crypto.PublicKey with a custom JSON +// marshaller that encodes the public key as a DER-encoded SubjectPublicKeyInfo. +type MarshalablePublicKey struct { + crypto.PublicKey +} + +func (pk MarshalablePublicKey) MarshalJSON() ([]byte, error) { + keyDER, err := x509.MarshalPKIXPublicKey(pk.PublicKey) + if err != nil { + return nil, err + } + return json.Marshal(keyDER) +} + +type HexMarshalableBytes []byte + +func (h HexMarshalableBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(fmt.Sprintf("%x", h)) +} + // IssuanceRequest describes a certificate issuance request +// +// It can be marshaled as JSON for logging purposes, though note that sctList and precertDER +// will be omitted from the marshaled output because they are unexported. type IssuanceRequest struct { - PublicKey crypto.PublicKey - SubjectKeyId []byte + // PublicKey is of type MarshalablePublicKey so we can log an IssuanceRequest as a JSON object. + PublicKey MarshalablePublicKey + SubjectKeyId HexMarshalableBytes - Serial []byte + Serial HexMarshalableBytes NotBefore time.Time NotAfter time.Time @@ -283,7 +319,7 @@ type IssuanceRequest struct { type issuanceToken struct { mu sync.Mutex template *x509.Certificate - pubKey any + pubKey MarshalablePublicKey // A pointer to the issuer that created this token. This token may only // be redeemed by the same issuer. issuer *Issuer @@ -324,7 +360,7 @@ func (i *Issuer) Prepare(prof *Profile, req *IssuanceRequest) ([]byte, *issuance } template.DNSNames = req.DNSNames - switch req.PublicKey.(type) { + switch req.PublicKey.PublicKey.(type) { case *rsa.PublicKey: if prof.omitKeyEncipherment { template.KeyUsage = x509.KeyUsageDigitalSignature @@ -360,7 +396,7 @@ func (i *Issuer) Prepare(prof *Profile, req *IssuanceRequest) ([]byte, *issuance // check that the tbsCertificate is properly formed by signing it // with a throwaway key and then linting it using zlint - lintCertBytes, err := i.Linter.Check(template, req.PublicKey, prof.lints) + lintCertBytes, err := i.Linter.Check(template, req.PublicKey.PublicKey, prof.lints) if err != nil { return nil, nil, fmt.Errorf("tbsCertificate linting failed: %w", err) } @@ -395,7 +431,7 @@ func (i *Issuer) Issue(token *issuanceToken) ([]byte, error) { return nil, errors.New("tried to redeem issuance token with the wrong issuer") } - return x509.CreateCertificate(rand.Reader, template, i.Cert.Certificate, token.pubKey, i.Signer) + return x509.CreateCertificate(rand.Reader, template, i.Cert.Certificate, token.pubKey.PublicKey, i.Signer) } // ContainsMustStaple returns true if the provided set of extensions includes @@ -430,7 +466,7 @@ func RequestFromPrecert(precert *x509.Certificate, scts []ct.SignedCertificateTi return nil, errors.New("provided certificate doesn't contain the CT poison extension") } return &IssuanceRequest{ - PublicKey: precert.PublicKey, + PublicKey: MarshalablePublicKey{precert.PublicKey}, SubjectKeyId: precert.SubjectKeyId, Serial: precert.SerialNumber.Bytes(), NotBefore: precert.NotBefore, diff --git a/issuance/cert_test.go b/issuance/cert_test.go index 339a1967e67..108ae76b3b6 100644 --- a/issuance/cert_test.go +++ b/issuance/cert_test.go @@ -83,21 +83,21 @@ func TestRequestValid(t *testing.T) { name: "unsupported key type", issuer: &Issuer{}, profile: &Profile{}, - request: &IssuanceRequest{PublicKey: &dsa.PublicKey{}}, + request: &IssuanceRequest{PublicKey: MarshalablePublicKey{&dsa.PublicKey{}}}, expectedError: "unsupported public key type", }, { name: "inactive (rsa)", issuer: &Issuer{}, profile: &Profile{}, - request: &IssuanceRequest{PublicKey: &rsa.PublicKey{}}, + request: &IssuanceRequest{PublicKey: MarshalablePublicKey{&rsa.PublicKey{}}}, expectedError: "inactive issuer cannot issue precert", }, { name: "inactive (ecdsa)", issuer: &Issuer{}, profile: &Profile{}, - request: &IssuanceRequest{PublicKey: &ecdsa.PublicKey{}}, + request: &IssuanceRequest{PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}}, expectedError: "inactive issuer cannot issue precert", }, { @@ -107,7 +107,7 @@ func TestRequestValid(t *testing.T) { }, profile: &Profile{}, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: []byte{0, 1, 2, 3, 4}, }, expectedError: "unexpected subject key ID length", @@ -119,7 +119,7 @@ func TestRequestValid(t *testing.T) { }, profile: &Profile{}, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, IncludeMustStaple: true, }, @@ -132,7 +132,7 @@ func TestRequestValid(t *testing.T) { }, profile: &Profile{}, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, IncludeCTPoison: true, sctList: []ct.SignedCertificateTimestamp{}, @@ -146,7 +146,7 @@ func TestRequestValid(t *testing.T) { }, profile: &Profile{}, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now().Add(time.Hour), NotAfter: fc.Now(), @@ -162,7 +162,7 @@ func TestRequestValid(t *testing.T) { maxValidity: time.Minute, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now(), NotAfter: fc.Now().Add(time.Hour - time.Second), @@ -178,7 +178,7 @@ func TestRequestValid(t *testing.T) { maxValidity: time.Hour, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now(), NotAfter: fc.Now().Add(time.Hour), @@ -195,7 +195,7 @@ func TestRequestValid(t *testing.T) { maxBackdate: time.Hour, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now().Add(-time.Hour * 2), NotAfter: fc.Now().Add(-time.Hour), @@ -212,7 +212,7 @@ func TestRequestValid(t *testing.T) { maxBackdate: time.Hour, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now().Add(time.Hour), NotAfter: fc.Now().Add(time.Hour * 2), @@ -228,7 +228,7 @@ func TestRequestValid(t *testing.T) { maxValidity: time.Hour * 2, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now(), NotAfter: fc.Now().Add(time.Hour), @@ -245,7 +245,7 @@ func TestRequestValid(t *testing.T) { maxValidity: time.Hour * 2, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now(), NotAfter: fc.Now().Add(time.Hour), @@ -262,7 +262,7 @@ func TestRequestValid(t *testing.T) { maxValidity: time.Hour * 2, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now(), NotAfter: fc.Now().Add(time.Hour), @@ -279,7 +279,7 @@ func TestRequestValid(t *testing.T) { maxValidity: time.Hour * 2, }, request: &IssuanceRequest{ - PublicKey: &ecdsa.PublicKey{}, + PublicKey: MarshalablePublicKey{&ecdsa.PublicKey{}}, SubjectKeyId: goodSKID, NotBefore: fc.Now(), NotAfter: fc.Now().Add(time.Hour), @@ -356,7 +356,7 @@ func TestIssue(t *testing.T) { pk, err := tc.generateFunc() test.AssertNotError(t, err, "failed to generate test key") lintCertBytes, issuanceToken, err := signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -399,7 +399,7 @@ func TestIssueCommonName(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") ir := &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com", "www.example.com"}, @@ -463,7 +463,7 @@ func TestIssueOmissions(t *testing.T) { pk, err := rsa.GenerateKey(rand.Reader, 2048) test.AssertNotError(t, err, "failed to generate test key") _, issuanceToken, err := signer.Prepare(prof, &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -492,7 +492,7 @@ func TestIssueCTPoison(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") _, issuanceToken, err := signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -537,7 +537,7 @@ func TestIssueSCTList(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") _, issuanceToken, err := signer.Prepare(enforceSCTsProfile, &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -601,7 +601,7 @@ func TestIssueMustStaple(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") _, issuanceToken, err := signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -636,7 +636,7 @@ func TestIssueBadLint(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") _, _, err = signer.Prepare(noSkipLintsProfile, &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example-com"}, @@ -665,7 +665,7 @@ func TestIssuanceToken(t *testing.T) { pk, err := rsa.GenerateKey(rand.Reader, 2048) test.AssertNotError(t, err, "failed to generate test key") _, issuanceToken, err := signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -682,7 +682,7 @@ func TestIssuanceToken(t *testing.T) { test.AssertContains(t, err.Error(), "issuance token already redeemed") _, issuanceToken, err = signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -712,7 +712,7 @@ func TestInvalidProfile(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") _, _, err = signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -724,7 +724,7 @@ func TestInvalidProfile(t *testing.T) { test.AssertError(t, err, "Invalid IssuanceRequest") _, _, err = signer.Prepare(defaultProfile(), &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, DNSNames: []string{"example.com"}, @@ -765,7 +765,7 @@ func TestMismatchedProfiles(t *testing.T) { pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) test.AssertNotError(t, err, "failed to generate test key") _, issuanceToken, err := issuer1.Prepare(cnProfile, &IssuanceRequest{ - PublicKey: pk.Public(), + PublicKey: MarshalablePublicKey{pk.Public()}, SubjectKeyId: goodSKID, Serial: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, CommonName: "example.com", diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 4b347e0be0a..83d6752ed1f 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -14,6 +14,7 @@ import ( "github.com/prometheus/client_golang/prometheus" berrors "github.com/letsencrypt/boulder/errors" + "github.com/letsencrypt/boulder/features" ) const ( @@ -274,11 +275,13 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision } batchDecision := allowedDecision newTATs := make(map[string]time.Time) + newBuckets := make(map[string]time.Time) + incrBuckets := make(map[string]increment) txnOutcomes := make(map[Transaction]string) for _, txn := range batch { - tat, exists := tats[txn.bucketKey] - if !exists { + tat, bucketExists := tats[txn.bucketKey] + if !bucketExists { // First request from this client. tat = l.clk.Now() } @@ -293,6 +296,15 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision if d.allowed && (tat != d.newTAT) && txn.spend { // New bucket state should be persisted. newTATs[txn.bucketKey] = d.newTAT + + if bucketExists { + incrBuckets[txn.bucketKey] = increment{ + cost: time.Duration(txn.cost * txn.limit.emissionInterval), + ttl: time.Duration(txn.limit.burstOffset), + } + } else { + newBuckets[txn.bucketKey] = d.newTAT + } } if !txn.spendOnly() { @@ -307,10 +319,28 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision } } - if batchDecision.allowed && len(newTATs) > 0 { - err = l.source.BatchSet(ctx, newTATs) - if err != nil { - return nil, err + if features.Get().IncrementRateLimits { + if batchDecision.allowed { + if len(newBuckets) > 0 { + err = l.source.BatchSet(ctx, newBuckets) + if err != nil { + return nil, err + } + } + + if len(incrBuckets) > 0 { + err = l.source.BatchIncrement(ctx, incrBuckets) + if err != nil { + return nil, err + } + } + } + } else { + if batchDecision.allowed && len(newTATs) > 0 { + err = l.source.BatchSet(ctx, newTATs) + if err != nil { + return nil, err + } } } @@ -365,10 +395,11 @@ func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decisio batchDecision := allowedDecision newTATs := make(map[string]time.Time) + incrBuckets := make(map[string]increment) for _, txn := range batch { - tat, exists := tats[txn.bucketKey] - if !exists { + tat, bucketExists := tats[txn.bucketKey] + if !bucketExists { // Ignore non-existent bucket. continue } @@ -382,13 +413,26 @@ func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decisio if d.allowed && tat != d.newTAT { // New bucket state should be persisted. newTATs[txn.bucketKey] = d.newTAT + incrBuckets[txn.bucketKey] = increment{ + cost: time.Duration(-txn.cost * txn.limit.emissionInterval), + ttl: time.Duration(txn.limit.burstOffset), + } } } - if len(newTATs) > 0 { - err = l.source.BatchSet(ctx, newTATs) - if err != nil { - return nil, err + if features.Get().IncrementRateLimits { + if len(incrBuckets) > 0 { + err = l.source.BatchIncrement(ctx, incrBuckets) + if err != nil { + return nil, err + } + } + } else { + if len(newTATs) > 0 { + err = l.source.BatchSet(ctx, newTATs) + if err != nil { + return nil, err + } } } return batchDecision, nil diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index e8761bcaeeb..41e89c36ad7 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -4,6 +4,7 @@ import ( "context" "math/rand/v2" "net" + "os" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/letsencrypt/boulder/config" berrors "github.com/letsencrypt/boulder/errors" + "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/metrics" "github.com/letsencrypt/boulder/test" ) @@ -38,6 +40,19 @@ func newTestTransactionBuilder(t *testing.T) *TransactionBuilder { } func setup(t *testing.T) (context.Context, map[string]*Limiter, *TransactionBuilder, clock.FakeClock, string) { + // Because all test cases in this file are affected by this feature flag, we + // want to run them all both with and without the feature flag. This way, we + // get one set of runs with and one set without. It's difficult to defer + // features.Reset() from the setup func (these tests are parallel); as long + // as this code doesn't test any other features, we don't need to. + // + // N.b. This is fragile. If a test case does call features.Reset(), it will + // not be testing the intended code path. But we expect to clean this up + // quickly. + if os.Getenv("BOULDER_CONFIG_DIR") == "test/config-next" { + features.Set(features.Config{IncrementRateLimits: true}) + } + testCtx := context.Background() clk := clock.NewFake() @@ -304,8 +319,8 @@ func TestLimiter_InitializationViaCheckAndSpend(t *testing.T) { test.AssertEquals(t, d.resetIn, time.Millisecond*50) test.AssertEquals(t, d.retryIn, time.Duration(0)) - // However, that cost should not be spent yet, a 0 cost check should - // tell us that we actually have 19 remaining. + // And that cost should have been spent; a 0 cost check should still + // tell us that we have 19 remaining. d, err = l.Check(testCtx, txn0) test.AssertNotError(t, err, "should not error") test.Assert(t, d.allowed, "should be allowed") diff --git a/ratelimits/source.go b/ratelimits/source.go index 77f43b73961..ec798322544 100644 --- a/ratelimits/source.go +++ b/ratelimits/source.go @@ -20,6 +20,13 @@ type source interface { // the underlying storage client implementation). BatchSet(ctx context.Context, bucketKeys map[string]time.Time) error + // BatchIncrement updates the TATs for the specified bucketKeys, similar to + // BatchSet. Implementations MUST ensure non-blocking operations by either: + // a) applying a deadline or timeout to the context WITHIN the method, or + // b) guaranteeing the operation will not block indefinitely (e.g. via + // the underlying storage client implementation). + BatchIncrement(ctx context.Context, buckets map[string]increment) error + // Get retrieves the TAT associated with the specified bucketKey (formatted // as 'name:id'). Implementations MUST ensure non-blocking operations by // either: @@ -45,6 +52,11 @@ type source interface { Delete(ctx context.Context, bucketKey string) error } +type increment struct { + cost time.Duration + ttl time.Duration +} + // inmem is an in-memory implementation of the source interface used for // testing. type inmem struct { @@ -52,6 +64,8 @@ type inmem struct { m map[string]time.Time } +var _ source = (*inmem)(nil) + func newInmem() *inmem { return &inmem{m: make(map[string]time.Time)} } @@ -65,6 +79,15 @@ func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) er return nil } +func (in *inmem) BatchIncrement(_ context.Context, bucketKeys map[string]increment) error { + in.Lock() + defer in.Unlock() + for k, v := range bucketKeys { + in.m[k] = in.m[k].Add(v.cost) + } + return nil +} + func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) { in.RLock() defer in.RUnlock() @@ -82,7 +105,7 @@ func (in *inmem) BatchGet(_ context.Context, bucketKeys []string) (map[string]ti for _, k := range bucketKeys { tat, ok := in.m[k] if !ok { - tats[k] = time.Time{} + continue } tats[k] = tat } diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index c8db1c4621e..cc2631c2f17 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -83,7 +83,6 @@ func (r *RedisSource) observeLatency(call string, latency time.Duration, err err // BatchSet stores TATs at the specified bucketKeys using a pipelined Redis // Transaction in order to reduce the number of round-trips to each Redis shard. -// An error is returned if the operation failed and nil otherwise. func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error { start := r.clk.Now() @@ -109,9 +108,35 @@ func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time return nil } -// Get retrieves the TAT at the specified bucketKey. An error is returned if the -// operation failed and nil otherwise. If the bucketKey does not exist, -// ErrBucketNotFound is returned. +// BatchIncrement updates TATs for the specified bucketKeys using a pipelined +// Redis Transaction in order to reduce the number of round-trips to each Redis +// shard. +func (r *RedisSource) BatchIncrement(ctx context.Context, buckets map[string]increment) error { + start := r.clk.Now() + + pipeline := r.client.Pipeline() + for bucketKey, incr := range buckets { + pipeline.IncrBy(ctx, bucketKey, incr.cost.Nanoseconds()) + pipeline.Expire(ctx, bucketKey, incr.ttl) + } + _, err := pipeline.Exec(ctx) + if err != nil { + r.observeLatency("batchincrby", r.clk.Since(start), err) + return err + } + + totalLatency := r.clk.Since(start) + perSetLatency := totalLatency / time.Duration(len(buckets)) + for range buckets { + r.observeLatency("batchincrby_entry", perSetLatency, nil) + } + + r.observeLatency("batchincrby", totalLatency, nil) + return nil +} + +// Get retrieves the TAT at the specified bucketKey. If the bucketKey does not +// exist, ErrBucketNotFound is returned. func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) { start := r.clk.Now() @@ -133,8 +158,8 @@ func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, err // BatchGet retrieves the TATs at the specified bucketKeys using a pipelined // Redis Transaction in order to reduce the number of round-trips to each Redis -// shard. An error is returned if the operation failed and nil otherwise. If a -// bucketKey does not exist, it WILL NOT be included in the returned map. +// shard. If a bucketKey does not exist, it WILL NOT be included in the returned +// map. func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) { start := r.clk.Now() @@ -184,9 +209,8 @@ func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[st return tats, nil } -// Delete deletes the TAT at the specified bucketKey ('name:id'). It returns an -// error if the operation failed and nil otherwise. A nil return value does not -// indicate that the bucketKey existed. +// Delete deletes the TAT at the specified bucketKey ('name:id'). A nil return +// value does not indicate that the bucketKey existed. func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error { start := r.clk.Now() @@ -201,7 +225,7 @@ func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error { } // Ping checks that each shard of the *redis.Ring is reachable using the PING -// command. It returns an error if any shard is unreachable and nil otherwise. +// command. func (r *RedisSource) Ping(ctx context.Context) error { start := r.clk.Now() diff --git a/ratelimits/source_redis_test.go b/ratelimits/source_redis_test.go index 11ed2715853..249a9dbe2c6 100644 --- a/ratelimits/source_redis_test.go +++ b/ratelimits/source_redis_test.go @@ -77,15 +77,16 @@ func TestRedisSource_BatchSetAndGet(t *testing.T) { "shard2": "10.33.33.5:4218", }) - now := clk.Now() - val1 := now.Add(time.Second) - val2 := now.Add(time.Second * 2) - val3 := now.Add(time.Second * 3) - set := map[string]time.Time{ - "test1": val1, - "test2": val2, - "test3": val3, + "test1": clk.Now().Add(time.Second), + "test2": clk.Now().Add(time.Second * 2), + "test3": clk.Now().Add(time.Second * 3), + } + + incr := map[string]increment{ + "test1": {time.Second, time.Minute}, + "test2": {time.Second * 2, time.Minute}, + "test3": {time.Second * 3, time.Minute}, } err := s.BatchSet(context.Background(), set) @@ -95,7 +96,17 @@ func TestRedisSource_BatchSetAndGet(t *testing.T) { test.AssertNotError(t, err, "BatchGet() should not error") for k, v := range set { - test.Assert(t, got[k].Equal(v), "BatchGet() should return the values set by BatchSet()") + test.AssertEquals(t, got[k], v) + } + + err = s.BatchIncrement(context.Background(), incr) + test.AssertNotError(t, err, "BatchIncrement() should not error") + + got, err = s.BatchGet(context.Background(), []string{"test1", "test2", "test3"}) + test.AssertNotError(t, err, "BatchGet() should not error") + + for k := range set { + test.AssertEquals(t, got[k], set[k].Add(incr[k].cost)) } // Test that BatchGet() returns a zero time for a key that does not exist. diff --git a/sa/database.go b/sa/database.go index ba3b7300375..9d83875802e 100644 --- a/sa/database.go +++ b/sa/database.go @@ -283,7 +283,7 @@ func initTables(dbMap *borp.DbMap) { dbMap.AddTableWithName(authzModel{}, "authz2").SetKeys(true, "ID") dbMap.AddTableWithName(orderToAuthzModel{}, "orderToAuthz2").SetKeys(false, "OrderID", "AuthzID") dbMap.AddTableWithName(recordedSerialModel{}, "serials").SetKeys(true, "ID") - dbMap.AddTableWithName(precertificateModel{}, "precertificates").SetKeys(true, "ID") + dbMap.AddTableWithName(lintingCertModel{}, "precertificates").SetKeys(true, "ID") dbMap.AddTableWithName(keyHashModel{}, "keyHashToSerial").SetKeys(true, "ID") dbMap.AddTableWithName(incidentModel{}, "incidents").SetKeys(true, "ID") dbMap.AddTable(incidentSerialModel{}) diff --git a/sa/db/boulder_sa/20230419000000_CombinedSchema.sql b/sa/db/boulder_sa/20230419000000_CombinedSchema.sql index 34d6f151cee..ff8e5432079 100644 --- a/sa/db/boulder_sa/20230419000000_CombinedSchema.sql +++ b/sa/db/boulder_sa/20230419000000_CombinedSchema.sql @@ -173,6 +173,9 @@ CREATE TABLE `orders` ( PARTITION BY RANGE(id) (PARTITION p_start VALUES LESS THAN (MAXVALUE)); +-- Note: This table's name is a historical artifact and it is now +-- used to store linting certificates, not precertificates. +-- See #6807. CREATE TABLE `precertificates` ( `id` bigint(20) NOT NULL AUTO_INCREMENT, `registrationID` bigint(20) NOT NULL, diff --git a/sa/model.go b/sa/model.go index fa3ce717a29..522d4a52d0f 100644 --- a/sa/model.go +++ b/sa/model.go @@ -160,7 +160,7 @@ const precertFields = "registrationID, serial, der, issued, expires" // SelectPrecertificate selects all fields of one precertificate object // identified by serial. func SelectPrecertificate(ctx context.Context, s db.OneSelector, serial string) (core.Certificate, error) { - var model precertificateModel + var model lintingCertModel err := s.SelectOne( ctx, &model, @@ -384,7 +384,7 @@ type recordedSerialModel struct { Expires time.Time } -type precertificateModel struct { +type lintingCertModel struct { ID int64 Serial string RegistrationID int64 diff --git a/sa/sa.go b/sa/sa.go index 90428a4f50a..18320f767a9 100644 --- a/sa/sa.go +++ b/sa/sa.go @@ -333,7 +333,11 @@ func (ssa *SQLStorageAuthority) SetCertificateStatusReady(ctx context.Context, r return &emptypb.Empty{}, nil } -// AddPrecertificate writes a record of a precertificate generation to the DB. +// AddPrecertificate writes a record of a linting certificate to the database. +// +// Note: The name "AddPrecertificate" is a historical artifact, and this is now +// always called with a linting certificate. See #6807. +// // Note: this is not idempotent: it does not protect against inserting the same // certificate multiple times. Calling code needs to first insert the cert's // serial into the Serials table to ensure uniqueness. @@ -348,7 +352,7 @@ func (ssa *SQLStorageAuthority) AddPrecertificate(ctx context.Context, req *sapb } serialHex := core.SerialToString(parsed.SerialNumber) - preCertModel := &precertificateModel{ + preCertModel := &lintingCertModel{ Serial: serialHex, RegistrationID: req.RegID, DER: req.Der, diff --git a/test/config-next/ra.json b/test/config-next/ra.json index ad320cab78a..c116c1fb7ff 100644 --- a/test/config-next/ra.json +++ b/test/config-next/ra.json @@ -130,7 +130,8 @@ }, "features": { "AsyncFinalize": true, - "UseKvLimitsForNewOrder": true + "UseKvLimitsForNewOrder": true, + "IncrementRateLimits": true }, "ctLogs": { "stagger": "500ms", diff --git a/test/config-next/wfe2.json b/test/config-next/wfe2.json index 78977574241..dbff6ddb1a6 100644 --- a/test/config-next/wfe2.json +++ b/test/config-next/wfe2.json @@ -127,9 +127,11 @@ "Overrides": "test/config-next/wfe2-ratelimit-overrides.yml" }, "features": { + "PropagateCancels": true, "ServeRenewalInfo": true, "CheckIdentifiersPaused": true, - "UseKvLimitsForNewOrder": true + "UseKvLimitsForNewOrder": true, + "IncrementRateLimits": true }, "certProfiles": { "legacy": "The normal profile you know and love", diff --git a/test/integration/otel_test.go b/test/integration/otel_test.go index f8ac7f92d66..4485390cea3 100644 --- a/test/integration/otel_test.go +++ b/test/integration/otel_test.go @@ -206,10 +206,12 @@ func TestTraces(t *testing.T) { traceID := traceIssuingTestCert(t) wfe := "boulder-wfe2" - sa := "boulder-sa" ra := "boulder-ra" ca := "boulder-ca" + // A very stripped-down version of the expected call graph of a full issuance + // flow: just enough to ensure that our otel tracing is working without + // asserting too much about the exact set of RPCs we use under the hood. expectedSpans := expectedSpans{ Operation: "TraceTest", Service: "integration.test", @@ -218,45 +220,13 @@ func TestTraces(t *testing.T) { {Operation: "/acme/new-nonce", Service: wfe, Children: []expectedSpans{ rpcSpan("nonce.NonceService/Nonce", wfe, "nonce-service")}}, httpSpan("/acme/new-acct", - redisPipelineSpan("get", wfe), - redisPipelineSpan("set", wfe), - rpcSpan("sa.StorageAuthorityReadOnly/KeyBlocked", wfe, sa), - rpcSpan("sa.StorageAuthorityReadOnly/GetRegistrationByKey", wfe, sa), - rpcSpan("ra.RegistrationAuthority/NewRegistration", wfe, ra, - rpcSpan("sa.StorageAuthority/KeyBlocked", ra, sa), - rpcSpan("sa.StorageAuthority/NewRegistration", ra, sa))), - httpSpan("/acme/new-order", - rpcSpan("sa.StorageAuthorityReadOnly/GetRegistration", wfe, sa), - redisPipelineSpan("get", wfe), - redisPipelineSpan("set", wfe), - rpcSpan("ra.RegistrationAuthority/NewOrder", wfe, ra, - rpcSpan("sa.StorageAuthority/GetOrderForNames", ra, sa), - rpcSpan("sa.StorageAuthority/NewOrderAndAuthzs", ra, sa))), - httpSpan("/acme/authz-v3/", - rpcSpan("ra.RegistrationAuthority/GetAuthorization", wfe, ra, - rpcSpan("sa.StorageAuthority/GetAuthorization2", ra, sa))), - httpSpan("/acme/chall-v3/", - rpcSpan("ra.RegistrationAuthority/GetAuthorization", wfe, ra, - rpcSpan("sa.StorageAuthority/GetAuthorization2", ra, sa)), - rpcSpan("ra.RegistrationAuthority/PerformValidation", wfe, ra, - rpcSpan("sa.StorageAuthority/GetRegistration", ra, sa))), + redisPipelineSpan("get", wfe)), + httpSpan("/acme/new-order"), + httpSpan("/acme/authz-v3/"), + httpSpan("/acme/chall-v3/"), httpSpan("/acme/finalize/", - rpcSpan("sa.StorageAuthorityReadOnly/GetOrder", wfe, sa), rpcSpan("ra.RegistrationAuthority/FinalizeOrder", wfe, ra, - rpcSpan("sa.StorageAuthority/KeyBlocked", ra, sa), - rpcSpan("sa.StorageAuthority/GetRegistration", ra, sa), - rpcSpan("sa.StorageAuthority/GetValidOrderAuthorizations2", ra, sa), - rpcSpan("sa.StorageAuthority/SetOrderProcessing", ra, sa), - rpcSpan("ca.CertificateAuthority/IssuePrecertificate", ra, ca), - redisPipelineSpan("get", ra), - rpcSpan("Publisher/SubmitToSingleCTWithResult", ra, "boulder-publisher"), - rpcSpan("ca.CertificateAuthority/IssueCertificateForPrecertificate", ra, ca), - redisPipelineSpan("set", ra), - rpcSpan("sa.StorageAuthority/FinalizeOrder", ra, sa))), - httpSpan("/acme/order/", - rpcSpan("sa.StorageAuthorityReadOnly/GetOrder", wfe, sa)), - httpSpan("/acme/cert/", - rpcSpan("sa.StorageAuthorityReadOnly/GetCertificate", wfe, sa)), + rpcSpan("ca.CertificateAuthority/IssueCertificateForPrecertificate", ra, ca))), }, } diff --git a/web/context.go b/web/context.go index 24943858947..a748137a0aa 100644 --- a/web/context.go +++ b/web/context.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" ) @@ -127,11 +128,13 @@ func (th *TopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Origin: r.Header.Get("Origin"), Extra: make(map[string]interface{}), } - // We specifically override the default r.Context() because we would prefer - // for clients to not be able to cancel our operations in arbitrary places. - // Instead we start a new context, and apply timeouts in our various RPCs. - ctx := context.WithoutCancel(r.Context()) - r = r.WithContext(ctx) + if !features.Get().PropagateCancels { + // We specifically override the default r.Context() because we would prefer + // for clients to not be able to cancel our operations in arbitrary places. + // Instead we start a new context, and apply timeouts in our various RPCs. + ctx := context.WithoutCancel(r.Context()) + r = r.WithContext(ctx) + } // Some clients will send a HTTP Host header that includes the default port // for the scheme that they are using. Previously when we were fronted by diff --git a/web/context_test.go b/web/context_test.go index a5e806c557c..ed98597cdc0 100644 --- a/web/context_test.go +++ b/web/context_test.go @@ -2,13 +2,16 @@ package web import ( "bytes" + "context" "crypto/tls" "fmt" "net/http" "net/http/httptest" "strings" "testing" + "time" + "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/test" ) @@ -117,3 +120,36 @@ func TestHostHeaderRewrite(t *testing.T) { req.Host = "localhost:123" th.ServeHTTP(httptest.NewRecorder(), req) } + +type cancelHandler struct { + res chan string +} + +func (ch cancelHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + ch.res <- r.Context().Err().Error() + case <-time.After(300 * time.Millisecond): + ch.res <- "300 ms passed" + } +} + +func TestPropagateCancel(t *testing.T) { + mockLog := blog.UseMock() + res := make(chan string) + features.Set(features.Config{PropagateCancels: true}) + th := NewTopHandler(mockLog, cancelHandler{res}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + req, err := http.NewRequestWithContext(ctx, "GET", "/thisisignored", &bytes.Reader{}) + if err != nil { + t.Error(err) + } + th.ServeHTTP(httptest.NewRecorder(), req) + }() + cancel() + result := <-res + if result != "context canceled" { + t.Errorf("expected 'context canceled', got %q", result) + } +} diff --git a/wfe2/wfe.go b/wfe2/wfe.go index ed5a371d0cf..0d076a96b60 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -624,6 +624,9 @@ func (wfe *WebFrontEndImpl) sendError(response http.ResponseWriter, logEvent *we } } } + if prob.HTTPStatus == http.StatusInternalServerError { + response.Header().Add(headerRetryAfter, "60") + } wfe.stats.httpErrorCount.With(prometheus.Labels{"type": string(prob.Type)}).Inc() web.SendError(wfe.log, response, logEvent, prob, ierr) } @@ -1207,8 +1210,7 @@ func (wfe *WebFrontEndImpl) prepChallengeForDisplay(request *http.Request, authz } // prepAuthorizationForDisplay takes a core.Authorization and prepares it for -// display to the client by clearing its ID and RegistrationID fields, and -// preparing all its challenges. +// display to the client by preparing all its challenges. func (wfe *WebFrontEndImpl) prepAuthorizationForDisplay(request *http.Request, authz *core.Authorization) { for i := range authz.Challenges { wfe.prepChallengeForDisplay(request, *authz, &authz.Challenges[i]) @@ -1219,9 +1221,6 @@ func (wfe *WebFrontEndImpl) prepAuthorizationForDisplay(request *http.Request, a authz.Challenges[i], authz.Challenges[j] = authz.Challenges[j], authz.Challenges[i] }) - authz.ID = "" - authz.RegistrationID = 0 - // The ACME spec forbids allowing "*" in authorization identifiers. Boulder // allows this internally as a means of tracking when an authorization // corresponds to a wildcard request (e.g. to handle CAA properly). We strip diff --git a/wfe2/wfe_test.go b/wfe2/wfe_test.go index 7699e65d9c4..f7eaf7d2157 100644 --- a/wfe2/wfe_test.go +++ b/wfe2/wfe_test.go @@ -3401,9 +3401,11 @@ func TestPrepAuthzForDisplay(t *testing.T) { // This modifies the authz in-place. wfe.prepAuthorizationForDisplay(&http.Request{Host: "localhost"}, authz) - // The ID and RegID should be empty, since they're not part of the ACME API object. - test.AssertEquals(t, authz.ID, "") - test.AssertEquals(t, authz.RegistrationID, int64(0)) + // Ensure ID and RegID are omitted. + authzJSON, err := json.Marshal(authz) + test.AssertNotError(t, err, "Failed to marshal authz") + test.AssertNotContains(t, string(authzJSON), "\"id\":\"12345\"") + test.AssertNotContains(t, string(authzJSON), "\"registrationID\":\"1\"") } func TestPrepRevokedAuthzForDisplay(t *testing.T) { @@ -3819,6 +3821,15 @@ func Test_sendError(t *testing.T) { test.AssertEquals(t, testResponse.Header().Get("Link"), "") } +func Test_sendErrorInternalServerError(t *testing.T) { + features.Reset() + wfe, _, _ := setupWFE(t) + testResponse := httptest.NewRecorder() + + wfe.sendError(testResponse, &web.RequestEvent{}, probs.ServerInternal("oh no"), nil) + test.AssertEquals(t, testResponse.Header().Get("Retry-After"), "60") +} + type mockSA struct { sapb.StorageAuthorityReadOnlyClient cert *corepb.Certificate