diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..c5c02df --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +custom: https://paypal.me/grepplabs?locale.x=en_GB diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..26d4b64 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,35 @@ +name: tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + check-latest: true + - run: go version + - name: Vendor + run: go mod vendor + - name: Build + run: go build -v ./... + - name: Vet + run: go vet ./... + - name: Test + run: go test -count=1 -v ./... + - name: golangci-lint + uses: golangci/golangci-lint-action@v4 + with: + version: v1.56.2 + skip-pkg-cache: true + skip-build-cache: true diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..8f84707 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,34 @@ +# options for analysis running +run: + # exit code when at least one issue was found, default is 1 + issues-exit-code: 1 + + # which dirs to skip: they won't be analyzed; + # can use regexp here: generated.*, regexp is applied on full path; + # default value is empty list, but next dirs are always skipped independently + # from this option's value: + # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ + skip-dirs: + - vendor + +linters: + enable: + - errcheck + - goconst + - godot + - gofmt + - goimports + - gosimple + - govet + - ineffassign + - staticcheck + - typecheck + - unparam + - unused + - exportloopref + +issues: + exclude-rules: + - path: _test\.go + linters: + - unparam \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b90b93f --- /dev/null +++ b/Makefile @@ -0,0 +1,38 @@ +.DEFAULT_GOAL := help + +ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) + +default: help + +.PHONY: help +help: + @grep -E '^[a-zA-Z%_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: test +test: ## Test + GO111MODULE=on go test -count=1 -mod=vendor -v ./... + +.PHONY: fmt +fmt: ## Go format + go fmt ./... + +.PHONY: vet +vet: ## Go vet + go vet ./... + +.PHONY: lint +lint: ## Lint + @golangci-lint run + +.PHONY: deps +deps: ## Get dependencies + GO111MODULE=on go get ./... + +.PHONY: vendor +vendor: ## Go vendor + GO111MODULE=on go mod vendor + +.PHONY: tidy +tidy: ## Go tidy + GO111MODULE=on go mod tidy + diff --git a/README.md b/README.md index e69de29..676174a 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,9 @@ +# cert-source + +[![Release](https://img.shields.io/github/v/release/grepplabs/cert-source?sort=semver)](https://github.com/grepplabs/cert-source/releases) +![Build](https://github.com/grepplabs/cert-source/workflows/tests/badge.svg) + +## Overview + +The cert-source is a library designed to help with loading of TLS certificates and to streamline the process of +certificate rotation. diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..62d2461 --- /dev/null +++ b/config/config.go @@ -0,0 +1,31 @@ +package config + +import ( + "time" +) + +type TLSServerConfig struct { + Enable bool `help:"Enable server-side TLS."` + Refresh time.Duration `default:"0s" help:"Interval for refreshing server TLS certificates."` + File TLSServerFiles `embed:"" prefix:"file."` +} + +type TLSServerFiles struct { + Key string `placeholder:"FILE" help:"Path to the server TLS key file."` + Cert string `placeholder:"FILE" help:"Path to the server TLS certificate file."` + ClientCAs string `placeholder:"FILE" name:"client-ca" help:"Optional path to server client CA file for client verification."` + ClientCLR string `placeholder:"FILE" name:"client-clr" help:"TLS X509 CLR signed be the client CA. If no revocation list is specified, only client CA is verified."` +} + +type TLSClientConfig struct { + Enable bool `help:"Enable client-side TLS."` + Refresh time.Duration `default:"0s" help:"Interval for refreshing client TLS certificates."` + InsecureSkipVerify bool `help:"Skip TLS verification on client side."` + File TLSClientFiles `embed:"" prefix:"file."` +} + +type TLSClientFiles struct { + Key string `placeholder:"FILE" help:"Optional path to client TLS key file."` + Cert string `placeholder:"FILE" help:"Optional path to client TLS certificate file."` + RootCAs string `placeholder:"FILE" name:"root-ca" help:"Optional path to client root CAs for server verification."` +} diff --git a/go.mod b/go.mod index bcb038f..4e5eedb 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/grepplabs/cert-source go 1.21 + +require github.com/stretchr/testify v1.8.4 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..fa4b6e6 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/testutil/certs.go b/internal/testutil/certs.go new file mode 100644 index 0000000..1bf7393 --- /dev/null +++ b/internal/testutil/certs.go @@ -0,0 +1,308 @@ +package testutil + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + mathrand "math/rand" + "net" + "net/http" + "os" + "reflect" + "time" + + tlsclient "github.com/grepplabs/cert-source/tls/client" +) + +func GenerateCRL(caX509Cert *x509.Certificate, caPrivateKey crypto.PrivateKey, certs []*x509.Certificate, crlFile *os.File) error { + revoked := make([]x509.RevocationListEntry, 0) + for _, cert := range certs { + revoked = append(revoked, x509.RevocationListEntry{ + SerialNumber: cert.SerialNumber, + RevocationTime: time.Now().Add(-1 * time.Minute), + }) + } + template := &x509.RevocationList{ + SignatureAlgorithm: x509.SHA256WithRSA, + RevokedCertificateEntries: revoked, + Number: big.NewInt(mathrand.Int63()), + ThisUpdate: time.Now().Add(-1 * time.Minute), + NextUpdate: time.Now().Add(60 * time.Minute), + } + signer, ok := caPrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("private key %s does not implement signer", reflect.TypeOf(caPrivateKey)) + } + derBytes, err := x509.CreateRevocationList(rand.Reader, template, caX509Cert, signer) + if err != nil { + return err + } + // Public key + err = pem.Encode(crlFile, &pem.Block{Type: "X509 CRL", Bytes: derBytes}) + if err != nil { + return err + } + err = crlFile.Sync() + if err != nil { + return err + } + return nil +} + +func GenerateCert(caCert *tls.Certificate, client bool, certFile *os.File, keyFile *os.File) (*tls.Certificate, *x509.Certificate, error) { + var certificate *x509.Certificate + if client { + certificate = &x509.Certificate{ + SerialNumber: big.NewInt(mathrand.Int63()), + Subject: pkix.Name{ + CommonName: fmt.Sprintf("client-%d", mathrand.Int63()), + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + } else { + certificate = &x509.Certificate{ + SerialNumber: big.NewInt(mathrand.Int63()), + Subject: pkix.Name{ + CommonName: fmt.Sprintf("server-%d", mathrand.Int63()), + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{[]byte{127, 0, 0, 1}}, + } + } + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + ca, err := x509.ParseCertificate(caCert.Certificate[0]) + if err != nil { + return nil, nil, err + } + cert, err := x509.CreateCertificate(rand.Reader, certificate, ca, &priv.PublicKey, caCert.PrivateKey) + if err != nil { + return nil, nil, err + } + // Public key + err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: cert}) + if err != nil { + return nil, nil, err + } + err = certFile.Sync() + if err != nil { + return nil, nil, err + } + // Private key + err = pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + if err != nil { + return nil, nil, err + } + err = keyFile.Sync() + if err != nil { + return nil, nil, err + } + // Load Cert + caTLS, err := tls.LoadX509KeyPair(certFile.Name(), keyFile.Name()) + if err != nil { + return nil, nil, err + } + x509Cert, err := x509.ParseCertificate(caTLS.Certificate[0]) + if err != nil { + return nil, nil, err + } + return &tls.Certificate{ + Certificate: [][]byte{cert}, + PrivateKey: priv, + }, x509Cert, nil +} + +func GenerateCA(certFile *os.File, keyFile *os.File) (*tls.Certificate, *x509.Certificate, error) { + certificate := &x509.Certificate{ + SerialNumber: big.NewInt(mathrand.Int63()), + Subject: pkix.Name{ + CommonName: "ca-cert", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + } + + caPriv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + caCert, err := x509.CreateCertificate(rand.Reader, certificate, certificate, &caPriv.PublicKey, caPriv) + if err != nil { + return nil, nil, err + } + + // Public key + err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCert}) + if err != nil { + return nil, nil, err + } + err = certFile.Sync() + if err != nil { + return nil, nil, err + } + // Private key + err = pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(caPriv)}) + if err != nil { + return nil, nil, err + } + err = keyFile.Sync() + if err != nil { + return nil, nil, err + } + // Load CA + caTLS, err := tls.LoadX509KeyPair(certFile.Name(), keyFile.Name()) + if err != nil { + return nil, nil, err + } + x509Cert, err := x509.ParseCertificate(caTLS.Certificate[0]) + if err != nil { + return nil, nil, err + } + return &tls.Certificate{ + Certificate: [][]byte{caCert}, + PrivateKey: caPriv, + }, x509Cert, nil +} + +type CertsBundle struct { + dirName string + + CACert *os.File + CAKey *os.File + CAEmptyCRL *os.File + CATLSCert *tls.Certificate + CAX509Cert *x509.Certificate + + ServerCert *os.File + ServerKey *os.File + ServerTLSCert *tls.Certificate + ServerX509Cert *x509.Certificate + + ClientCert *os.File + ClientKey *os.File + ClientCRL *os.File + ClientTLSCert *tls.Certificate + ClientX509Cert *x509.Certificate +} + +func (bundle *CertsBundle) Close() { + _ = os.Remove(bundle.CACert.Name()) + _ = os.Remove(bundle.CAKey.Name()) + _ = os.Remove(bundle.ServerCert.Name()) + _ = os.Remove(bundle.ServerKey.Name()) + _ = os.Remove(bundle.ClientCert.Name()) + _ = os.Remove(bundle.ClientKey.Name()) + _ = os.Remove(bundle.dirName) +} + +func (bundle *CertsBundle) NewHttpClient() *http.Client { + return &http.Client{ + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithRootCA(bundle.CAX509Cert), tlsclient.WithClientCertificate(bundle.ClientTLSCert)), + } +} + +func NewCertsBundle() *CertsBundle { + dirName, err := os.MkdirTemp("", "tls-test-") + if err != nil { + panic(err) + } + bundle := &CertsBundle{} + bundle.CACert, err = os.CreateTemp(dirName, "ca-cert-") + if err != nil { + panic(err) + } + defer closeFile(bundle.CACert) + + bundle.CAKey, err = os.CreateTemp(dirName, "ca-key-") + if err != nil { + panic(err) + } + defer closeFile(bundle.CAKey) + + bundle.CAEmptyCRL, err = os.CreateTemp(dirName, "ca-empty-crl") + if err != nil { + panic(err) + } + defer closeFile(bundle.CAEmptyCRL) + + bundle.ServerCert, err = os.CreateTemp(dirName, "server-cert-") + if err != nil { + panic(err) + } + defer closeFile(bundle.ServerCert) + + bundle.ServerKey, err = os.CreateTemp(dirName, "server-key-") + if err != nil { + panic(err) + } + defer closeFile(bundle.ServerKey) + + bundle.ClientCert, err = os.CreateTemp(dirName, "client-cert-") + if err != nil { + panic(err) + } + defer closeFile(bundle.ClientCert) + + bundle.ClientKey, err = os.CreateTemp("", "client-key-") + if err != nil { + panic(err) + } + defer closeFile(bundle.ClientKey) + + bundle.ClientCRL, err = os.CreateTemp("", "client-crl-") + if err != nil { + panic(err) + } + defer closeFile(bundle.ClientCRL) + + // generate certs + bundle.CATLSCert, bundle.CAX509Cert, err = GenerateCA(bundle.CACert, bundle.CAKey) + if err != nil { + panic(err) + } + bundle.ServerTLSCert, bundle.ServerX509Cert, err = GenerateCert(bundle.CATLSCert, false, bundle.ServerCert, bundle.ServerKey) + if err != nil { + panic(err) + } + bundle.ClientTLSCert, bundle.ClientX509Cert, err = GenerateCert(bundle.CATLSCert, true, bundle.ClientCert, bundle.ClientKey) + if err != nil { + panic(err) + } + // generate CRLs + err = GenerateCRL(bundle.CAX509Cert, bundle.CATLSCert.PrivateKey, []*x509.Certificate{}, bundle.CAEmptyCRL) + if err != nil { + panic(err) + } + err = GenerateCRL(bundle.CAX509Cert, bundle.CATLSCert.PrivateKey, []*x509.Certificate{bundle.ClientX509Cert}, bundle.ClientCRL) + if err != nil { + panic(err) + } + return bundle +} + +func closeFile(file *os.File) { + err := file.Close() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + } +} diff --git a/tls/certutil/helper.go b/tls/certutil/helper.go new file mode 100644 index 0000000..38c31c8 --- /dev/null +++ b/tls/certutil/helper.go @@ -0,0 +1,77 @@ +package certutil + +import ( + "bytes" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" +) + +const ( + X509CRLBlockType = "X509 CRL" + CertificateBlockType = "CERTIFICATE" +) + +func ParseCRLsPEM(pemCrls []byte) ([]*x509.RevocationList, error) { + ok := false + var lists []*x509.RevocationList + for len(pemCrls) > 0 { + var block *pem.Block + block, pemCrls = pem.Decode(pemCrls) + if block == nil { + break + } + if block.Type != X509CRLBlockType { + continue + } + list, err := x509.ParseRevocationList(block.Bytes) + if err != nil { + return lists, err + } + lists = append(lists, list) + ok = true + } + if !ok { + return lists, errors.New("data does not contain any valid CRL") + } + return lists, nil +} + +func ParseCertsPEM(pemCerts []byte) ([]*x509.Certificate, error) { + ok := false + var certs []*x509.Certificate + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != CertificateBlockType || len(block.Headers) != 0 { + continue + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return certs, err + } + + certs = append(certs, cert) + ok = true + } + + if !ok { + return certs, errors.New("data does not contain any valid RSA or ECDSA certificates") + } + return certs, nil +} + +func GetHexFormatted(buf []byte, sep string) string { + var ret bytes.Buffer + for _, cur := range buf { + if ret.Len() > 0 { + _, _ = ret.WriteString(sep) + } + _, _ = fmt.Fprintf(&ret, "%02x", cur) + } + return ret.String() +} diff --git a/tls/client/client.go b/tls/client/client.go new file mode 100644 index 0000000..8e74171 --- /dev/null +++ b/tls/client/client.go @@ -0,0 +1,54 @@ +package client + +import ( + "crypto/tls" + "errors" + "log/slog" + "time" + + "github.com/grepplabs/cert-source/tls/client/source" +) + +const ( + initLoadTimeout = 5 * time.Second +) + +type TLSClientConfigFunc func() *tls.Config + +func NewTLSClientConfigFunc(logger *slog.Logger, src source.ClientCertsSource) (TLSClientConfigFunc, error) { + store, err := NewTLSClientCertsStore(logger, src) + if err != nil { + return nil, err + } + return func() *tls.Config { + cs := store.LoadClientCerts() + return &tls.Config{ + RootCAs: cs.RootCAs, + InsecureSkipVerify: cs.InsecureSkipVerify, + GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return store.LoadClientCerts().Certificate, nil + }, + } + }, nil +} + +func NewTLSClientCertsStore(logger *slog.Logger, src source.ClientCertsSource) (*source.ClientCertsStore, error) { + store := source.NewClientCertsStore(logger) + logger.Info("initial client certs loading") + + certsChan := src.ClientCerts() + + select { + case certs := <-certsChan: + store.SetClientCerts(certs) + case <-time.After(initLoadTimeout): + return nil, errors.New("get client certs timeout") + } + + go func() { + for certs := range certsChan { + store.SetClientCerts(certs) + } + }() + return store, nil +} diff --git a/tls/client/config/config.go b/tls/client/config/config.go new file mode 100644 index 0000000..f66281a --- /dev/null +++ b/tls/client/config/config.go @@ -0,0 +1,27 @@ +package config + +import ( + "fmt" + "log/slog" + + "github.com/grepplabs/cert-source/config" + tlsclient "github.com/grepplabs/cert-source/tls/client" + "github.com/grepplabs/cert-source/tls/client/filesource" +) + +func GetTLSClientConfigFunc(logger *slog.Logger, conf *config.TLSClientConfig) (tlsclient.TLSClientConfigFunc, error) { + if !conf.Enable { + return nil, nil + } + fs, err := filesource.New( + filesource.WithLogger(logger.With("tls", "client")), + filesource.WithRefresh(conf.Refresh), + filesource.WithInsecureSkipVerify(conf.InsecureSkipVerify), + filesource.WithClientCert(conf.File.Cert, conf.File.Key), + filesource.WithClientRootCAs(conf.File.RootCAs), + ) + if err != nil { + return nil, fmt.Errorf("setup client cert file source: %w", err) + } + return tlsclient.NewTLSClientConfigFunc(logger, fs) +} diff --git a/tls/client/config/config_test.go b/tls/client/config/config_test.go new file mode 100644 index 0000000..96add83 --- /dev/null +++ b/tls/client/config/config_test.go @@ -0,0 +1,31 @@ +package config + +import ( + "log/slog" + "testing" + + "github.com/grepplabs/cert-source/config" + "github.com/grepplabs/cert-source/internal/testutil" + "github.com/stretchr/testify/require" +) + +func TestGetClientTLSConfig(t *testing.T) { + bundle := testutil.NewCertsBundle() + defer bundle.Close() + tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{ + Enable: true, + Refresh: 0, + File: config.TLSClientFiles{ + Key: bundle.ClientKey.Name(), + Cert: bundle.ClientCert.Name(), + RootCAs: bundle.CACert.Name(), + }, + }) + require.NoError(t, err) + tlsConfig := tlsConfigFunc() + require.NotNil(t, tlsConfig.RootCAs) + + clientCert, err := tlsConfig.GetClientCertificate(nil) + require.NoError(t, err) + require.NotNil(t, clientCert) +} diff --git a/tls/client/filesource/filesource.go b/tls/client/filesource/filesource.go new file mode 100644 index 0000000..97d1d0e --- /dev/null +++ b/tls/client/filesource/filesource.go @@ -0,0 +1,120 @@ +package filesource + +import ( + "errors" + "log/slog" + "os" + "sync/atomic" + "time" + + tlscert "github.com/grepplabs/cert-source/tls/client/source" + "github.com/grepplabs/cert-source/tls/watcher" +) + +type fileSource struct { + insecureSkipVerify bool + certFile string + keyFile string + rootCAsFile string + refresh time.Duration + logger *slog.Logger + notifyFunc func() + lastClientCerts atomic.Pointer[tlscert.ClientCerts] +} + +func New(opts ...Option) (tlscert.ClientCertsSource, error) { + s := &fileSource{ + logger: slog.Default(), + } + for _, opt := range opts { + opt(s) + } + lastClientCerts, err := s.getClientCerts() + if err != nil { + return nil, err + } + s.lastClientCerts.Store(lastClientCerts) + return s, nil +} + +func MustNew(opts ...Option) tlscert.ClientCertsSource { + serverSource, err := New(opts...) + if err != nil { + panic(`filesource: New(): ` + err.Error()) + } + return serverSource +} + +func (s *fileSource) getClientCerts() (*tlscert.ClientCerts, error) { + pemBlocks, err := s.Load() + if err != nil { + return nil, err + } + certificate, err := pemBlocks.Certificate() + if err != nil { + return nil, err + } + rootCAs, err := pemBlocks.RootCAs() + if err != nil { + return nil, err + } + return &tlscert.ClientCerts{ + InsecureSkipVerify: s.insecureSkipVerify, + Certificate: certificate, + RootCAs: rootCAs, + Checksum: pemBlocks.Checksum(), + }, nil +} + +func (s *fileSource) refreshClientCerts() (*tlscert.ClientCerts, error) { + clientCerts, err := s.getClientCerts() + if err != nil { + return nil, err + } + s.lastClientCerts.Store(clientCerts) + return clientCerts, nil +} + +func (s *fileSource) ClientCerts() chan tlscert.ClientCerts { + initialClientCert := s.lastClientCerts.Load() + ch := make(chan tlscert.ClientCerts, 1) + if initialClientCert != nil { + ch <- *initialClientCert + } + if s.refresh <= 0 { + close(ch) + } else { + go func() { + watcher.Watch(s.logger, ch, s.refresh, initialClientCert, s.refreshClientCerts, s.notifyFunc) + close(ch) + }() + } + return ch +} + +func (s *fileSource) Load() (pemBlocks *tlscert.ClientPEMs, err error) { + pemBlocks = &tlscert.ClientPEMs{} + + if (s.certFile == "") != (s.keyFile == "") { + return nil, errors.New("cert file source: both certFile and keyFile must be set or be empty") + } + if s.certFile != "" && s.keyFile != "" { + if pemBlocks.CertPEMBlock, err = s.readFile(s.certFile); err != nil { + return nil, err + } + if pemBlocks.KeyPEMBlock, err = s.readFile(s.keyFile); err != nil { + return nil, err + } + } + if pemBlocks.RootCAsPEMBlock, err = s.readFile(s.rootCAsFile); err != nil { + return nil, err + } + return pemBlocks, nil +} + +func (s *fileSource) readFile(name string) ([]byte, error) { + if name == "" { + return nil, nil + } + return os.ReadFile(name) +} diff --git a/tls/client/filesource/filesource_test.go b/tls/client/filesource/filesource_test.go new file mode 100644 index 0000000..84a57bf --- /dev/null +++ b/tls/client/filesource/filesource_test.go @@ -0,0 +1,84 @@ +package filesource + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/grepplabs/cert-source/internal/testutil" + tlsclient "github.com/grepplabs/cert-source/tls/client" + servertls "github.com/grepplabs/cert-source/tls/server" + serverfilesource "github.com/grepplabs/cert-source/tls/server/filesource" + "github.com/stretchr/testify/require" +) + +func TestCertRotation(t *testing.T) { + bundle1 := testutil.NewCertsBundle() + defer bundle1.Close() + + bundle2 := testutil.NewCertsBundle() + defer bundle2.Close() + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + rotatedCh := make(chan struct{}, 1) + notifyFunc := func() { + rotatedCh <- struct{}{} + } + clientSource := MustNew( + WithClientRootCAs(bundle1.CACert.Name()), + WithClientCert(bundle1.ClientCert.Name(), bundle1.ClientKey.Name()), + WithRefresh(1*time.Second), + WithNotifyFunc(notifyFunc), + ).(*fileSource) + + clientCertsStore, err := tlsclient.NewTLSClientCertsStore(slog.Default(), clientSource) + require.NoError(t, err) + + serverSource := serverfilesource.MustNew( + serverfilesource.WithX509KeyPair(bundle1.ServerCert.Name(), bundle1.ServerKey.Name()), + serverfilesource.WithClientAuthFile(bundle1.CACert.Name()), + serverfilesource.WithClientCRLFile(bundle1.CAEmptyCRL.Name()), + serverfilesource.WithRefresh(1*time.Second), + serverfilesource.WithNotifyFunc(notifyFunc), + ) + ts.TLS = servertls.MustNewServerConfig(slog.Default(), serverSource) + ts.StartTLS() + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + require.NoError(t, err) + + // when + client := &http.Client{ + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)), + } + _, err = client.Do(req) + require.NoError(t, err) + + require.NoError(t, os.Rename(bundle2.ClientCert.Name(), bundle1.ClientCert.Name())) + require.NoError(t, os.Rename(bundle2.ClientKey.Name(), bundle1.ClientKey.Name())) + + select { + case <-rotatedCh: + t.Log("certificates were changed") + time.Sleep(100 * time.Millisecond) + case <-time.After(3 * time.Second): + t.Fatal("expected certificate change notification") + } + + // old client - bad certificate + // create new client as connection can be kept alive + client = &http.Client{ + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)), + } + _, err = client.Do(req) + require.NotNil(t, err) + require.Contains(t, err.Error(), "unknown certificate authority") + +} diff --git a/tls/client/filesource/option.go b/tls/client/filesource/option.go new file mode 100644 index 0000000..1ae39fb --- /dev/null +++ b/tls/client/filesource/option.go @@ -0,0 +1,45 @@ +package filesource + +import ( + "log/slog" + "time" +) + +type Option func(*fileSource) + +func WithLogger(logger *slog.Logger) Option { + return func(c *fileSource) { + c.logger = logger + } +} + +func WithClientCert(certFile, keyFile string) Option { + return func(c *fileSource) { + c.certFile = certFile + c.keyFile = keyFile + } +} + +func WithClientRootCAs(rootCAsFile string) Option { + return func(c *fileSource) { + c.rootCAsFile = rootCAsFile + } +} + +func WithInsecureSkipVerify(insecureSkipVerify bool) Option { + return func(c *fileSource) { + c.insecureSkipVerify = insecureSkipVerify + } +} + +func WithRefresh(refresh time.Duration) Option { + return func(c *fileSource) { + c.refresh = refresh + } +} + +func WithNotifyFunc(notifyFunc func()) Option { + return func(c *fileSource) { + c.notifyFunc = notifyFunc + } +} diff --git a/tls/client/roundtripper.go b/tls/client/roundtripper.go new file mode 100644 index 0000000..bbcb103 --- /dev/null +++ b/tls/client/roundtripper.go @@ -0,0 +1,89 @@ +package client + +import ( + "crypto/tls" + "crypto/x509" + "net/http" + + "github.com/grepplabs/cert-source/tls/client/source" +) + +type RoundTripper struct { + transport *http.Transport +} + +type RoundTripperOption func(*RoundTripper) + +func WithClientTLSConfig(tlsClientConfig *tls.Config) RoundTripperOption { + return func(rt *RoundTripper) { + rt.transport.TLSClientConfig = tlsClientConfig + } +} + +func WithClientCertsStore(source *source.ClientCertsStore) RoundTripperOption { + return func(rt *RoundTripper) { + cs := source.LoadClientCerts() + if rt.transport.TLSClientConfig == nil { + rt.transport.TLSClientConfig = &tls.Config{} + } + rt.transport.TLSClientConfig.RootCAs = cs.RootCAs + rt.transport.TLSClientConfig.InsecureSkipVerify = cs.InsecureSkipVerify + rt.transport.TLSClientConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return source.LoadClientCerts().Certificate, nil + } + } +} + +func WithRootCA(cert *x509.Certificate) RoundTripperOption { + certPool := x509.NewCertPool() + certPool.AddCert(cert) + return WithRootCAs(certPool) +} + +func WithRootCAs(rootCAs *x509.CertPool) RoundTripperOption { + return func(rt *RoundTripper) { + if rt.transport.TLSClientConfig == nil { + rt.transport.TLSClientConfig = &tls.Config{} + } + rt.transport.TLSClientConfig.RootCAs = rootCAs + } +} + +func WithClientTLSSkipVerify(skipVerify bool) RoundTripperOption { + return func(rt *RoundTripper) { + if rt.transport.TLSClientConfig == nil { + rt.transport.TLSClientConfig = &tls.Config{} + } + rt.transport.TLSClientConfig.InsecureSkipVerify = skipVerify + } +} + +func WithClientCertificate(clientCert *tls.Certificate) RoundTripperOption { + return func(rt *RoundTripper) { + if rt.transport.TLSClientConfig == nil { + rt.transport.TLSClientConfig = &tls.Config{} + } + rt.transport.TLSClientConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return clientCert, nil + } + } +} + +func NewRoundTripper(transport *http.Transport, options ...RoundTripperOption) *RoundTripper { + rt := &RoundTripper{ + transport: transport, + } + for _, option := range options { + option(rt) + } + return rt +} + +func NewDefaultRoundTripper(options ...RoundTripperOption) *RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + return NewRoundTripper(transport, options...) +} + +func (p *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return p.transport.RoundTrip(req) +} diff --git a/tls/client/source/pems.go b/tls/client/source/pems.go new file mode 100644 index 0000000..b8d530d --- /dev/null +++ b/tls/client/source/pems.go @@ -0,0 +1,47 @@ +package source + +import ( + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "errors" +) + +type ClientPEMsLoader interface { + Load() (*ClientPEMs, error) +} + +type ClientPEMs struct { + CertPEMBlock []byte + KeyPEMBlock []byte + RootCAsPEMBlock []byte +} + +func (s ClientPEMs) Checksum() []byte { + hash := sha256.New() + hash.Write(s.CertPEMBlock) + hash.Write(s.KeyPEMBlock) + return hash.Sum(s.RootCAsPEMBlock) +} + +func (s ClientPEMs) Certificate() (*tls.Certificate, error) { + if len(s.CertPEMBlock) == 0 || len(s.KeyPEMBlock) == 0 { + return nil, nil + } + cert, err := tls.X509KeyPair(s.CertPEMBlock, s.KeyPEMBlock) + if err != nil { + return nil, err + } + return &cert, nil +} + +func (s ClientPEMs) RootCAs() (*x509.CertPool, error) { + if len(s.RootCAsPEMBlock) == 0 { + return nil, nil + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(s.RootCAsPEMBlock) { + return nil, errors.New("client PEMs: building client CAs failed") + } + return certPool, nil +} diff --git a/tls/client/source/store.go b/tls/client/source/store.go new file mode 100644 index 0000000..91b33b2 --- /dev/null +++ b/tls/client/source/store.go @@ -0,0 +1,69 @@ +package source + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "fmt" + "log/slog" + "sync/atomic" +) + +type ClientCertsSource interface { + ClientCerts() chan ClientCerts +} + +type ClientCerts struct { + InsecureSkipVerify bool + Certificate *tls.Certificate + RootCAs *x509.CertPool + Checksum []byte +} + +func (s ClientCerts) GetChecksum() []byte { + return s.Checksum +} + +type ClientCertsStore struct { + cs atomic.Pointer[ClientCerts] + logger *slog.Logger +} + +func NewClientCertsStore(logger *slog.Logger) *ClientCertsStore { + s := &ClientCertsStore{ + logger: logger, + } + s.cs.Store(&ClientCerts{}) + return s +} + +func (s *ClientCertsStore) LoadClientCerts() ClientCerts { + return *s.cs.Load() +} + +func (s *ClientCertsStore) SetClientCerts(certs ClientCerts) { + s.cs.Store(&certs) + s.logger.Info(fmt.Sprintf("stored x509 client root certs, client cert [%s]", name(certs.Certificate))) +} + +func name(cert *tls.Certificate) string { + if cert == nil { + return "" + } + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return "" + } + return fmt.Sprintf("%s=%s", getHexFormatted(x509Cert.SerialNumber.Bytes(), ":"), x509Cert.Subject.CommonName) +} + +func getHexFormatted(buf []byte, sep string) string { + var ret bytes.Buffer + for _, cur := range buf { + if ret.Len() > 0 { + _, _ = fmt.Fprint(&ret, sep) + } + _, _ = fmt.Fprintf(&ret, "%02x", cur) + } + return ret.String() +} diff --git a/tls/server/config/config.go b/tls/server/config/config.go new file mode 100644 index 0000000..c928387 --- /dev/null +++ b/tls/server/config/config.go @@ -0,0 +1,29 @@ +package config + +import ( + "crypto/tls" + "fmt" + "log/slog" + + "github.com/grepplabs/cert-source/config" + tlsserver "github.com/grepplabs/cert-source/tls/server" + "github.com/grepplabs/cert-source/tls/server/filesource" +) + +func GetServerTLSConfig(logger *slog.Logger, conf *config.TLSServerConfig) (*tls.Config, error) { + fs, err := filesource.New( + filesource.WithLogger(logger), + filesource.WithX509KeyPair(conf.File.Cert, conf.File.Key), + filesource.WithClientAuthFile(conf.File.ClientCAs), + filesource.WithClientCRLFile(conf.File.ClientCLR), + filesource.WithRefresh(conf.Refresh), + ) + if err != nil { + return nil, fmt.Errorf("setup server cert file source: %w", err) + } + tlsConfig, err := tlsserver.NewServerConfig(logger, fs) + if err != nil { + return nil, fmt.Errorf("setup server TLS config: %w", err) + } + return tlsConfig, nil +} diff --git a/tls/server/config/config_test.go b/tls/server/config/config_test.go new file mode 100644 index 0000000..15e035b --- /dev/null +++ b/tls/server/config/config_test.go @@ -0,0 +1,32 @@ +package config + +import ( + "crypto/tls" + "log/slog" + "testing" + + "github.com/grepplabs/cert-source/config" + "github.com/grepplabs/cert-source/internal/testutil" + "github.com/stretchr/testify/require" +) + +func TestGetServerTLSConfig(t *testing.T) { + bundle := testutil.NewCertsBundle() + defer bundle.Close() + + tlsConfig, err := GetServerTLSConfig(slog.Default(), &config.TLSServerConfig{ + Enable: true, + Refresh: 0, + File: config.TLSServerFiles{ + Key: bundle.ServerKey.Name(), + Cert: bundle.ServerCert.Name(), + ClientCAs: bundle.CACert.Name(), + ClientCLR: bundle.ClientCRL.Name(), + }, + }) + require.NoError(t, err) + require.NotNil(t, tlsConfig.ClientCAs) + require.Equal(t, tlsConfig.ClientAuth, tls.RequireAndVerifyClientCert) + require.NotEmpty(t, tlsConfig.Certificates) + require.NotNil(t, tlsConfig.VerifyPeerCertificate) +} diff --git a/tls/server/filesource/filesource.go b/tls/server/filesource/filesource.go new file mode 100644 index 0000000..bf196a0 --- /dev/null +++ b/tls/server/filesource/filesource.go @@ -0,0 +1,146 @@ +package filesource + +import ( + "errors" + "log/slog" + "os" + "path/filepath" + "sync/atomic" + "time" + + tlscert "github.com/grepplabs/cert-source/tls/server/source" + "github.com/grepplabs/cert-source/tls/watcher" +) + +const ( + defaultCertFile = "server-crt.pem" + defaultKeyFile = "server-key.pem" +) + +type fileSource struct { + certFile string + keyFile string + clientAuthFile string + clientCRLFile string + refresh time.Duration + logger *slog.Logger + notifyFunc func() + lastServerCerts atomic.Pointer[tlscert.ServerCerts] +} + +func New(opts ...Option) (tlscert.ServerCertsSource, error) { + s := &fileSource{ + logger: slog.Default(), + } + if dir, err := os.Getwd(); err == nil { + s.certFile = filepath.Join(dir, defaultCertFile) + s.keyFile = filepath.Join(dir, defaultKeyFile) + } else { + return nil, err + } + for _, opt := range opts { + opt(s) + } + lastServerCerts, err := s.getServerCerts() + if err != nil { + return nil, err + } + s.lastServerCerts.Store(lastServerCerts) + return s, nil +} + +func MustNew(opts ...Option) tlscert.ServerCertsSource { + serverSource, err := New(opts...) + if err != nil { + panic(`filesource: New(): ` + err.Error()) + } + return serverSource +} + +func (s *fileSource) getServerCerts() (*tlscert.ServerCerts, error) { + pemBlocks, err := s.Load() + if err != nil { + return nil, err + } + certificates, err := pemBlocks.Certificates() + if err != nil { + return nil, err + } + clientCAs, err := pemBlocks.ClientCAs() + if err != nil { + return nil, err + } + clientCRLs, err := pemBlocks.ClientCRLs() + if err != nil { + return nil, err + } + if err = pemBlocks.ValidateCRLs(); err != nil { + return nil, err + } + return &tlscert.ServerCerts{ + Certificates: certificates, + ClientCAs: clientCAs, + ClientCRLs: clientCRLs, + RevokedSerialNumbers: tlscert.NewRevokedSerialNumbers(clientCRLs), + Checksum: pemBlocks.Checksum(), + }, nil +} + +func (s *fileSource) refreshServerCerts() (*tlscert.ServerCerts, error) { + serverCerts, err := s.getServerCerts() + if err != nil { + return nil, err + } + s.lastServerCerts.Store(serverCerts) + return serverCerts, nil +} + +func (s *fileSource) ServerCerts() chan tlscert.ServerCerts { + initialServerCert := s.lastServerCerts.Load() + ch := make(chan tlscert.ServerCerts, 1) + if initialServerCert != nil { + ch <- *initialServerCert + } + if s.refresh <= 0 { + close(ch) + } else { + go func() { + watcher.Watch(s.logger, ch, s.refresh, initialServerCert, s.refreshServerCerts, s.notifyFunc) + close(ch) + }() + } + return ch +} + +func (s *fileSource) Load() (pemBlocks *tlscert.ServerPEMs, err error) { + if s.certFile == "" { + return nil, errors.New("cert file source: certFile is required") + } + if s.keyFile == "" { + return nil, errors.New("cert file source: keyFile is required") + } + if s.clientAuthFile == "" && s.clientCRLFile != "" { + return nil, errors.New("cert file source: clientAuthFile is required when clientCRLFile is provided") + } + pemBlocks = &tlscert.ServerPEMs{} + if pemBlocks.CertPEMBlock, err = s.readFile(s.certFile); err != nil { + return nil, err + } + if pemBlocks.KeyPEMBlock, err = s.readFile(s.keyFile); err != nil { + return nil, err + } + if pemBlocks.ClientAuthPEMBlock, err = s.readFile(s.clientAuthFile); err != nil { + return nil, err + } + if pemBlocks.CRLPEMBlock, err = s.readFile(s.clientCRLFile); err != nil { + return nil, err + } + return pemBlocks, nil +} + +func (s *fileSource) readFile(name string) ([]byte, error) { + if name == "" { + return nil, nil + } + return os.ReadFile(name) +} diff --git a/tls/server/filesource/filesource_test.go b/tls/server/filesource/filesource_test.go new file mode 100644 index 0000000..db60c7a --- /dev/null +++ b/tls/server/filesource/filesource_test.go @@ -0,0 +1,214 @@ +package filesource + +import ( + "crypto/tls" + "crypto/x509" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/grepplabs/cert-source/internal/testutil" + tlsclient "github.com/grepplabs/cert-source/tls/client" + servertls "github.com/grepplabs/cert-source/tls/server" + "github.com/stretchr/testify/require" +) + +func TestServerConfig(t *testing.T) { + logger := slog.Default() + bundle := testutil.NewCertsBundle() + defer bundle.Close() + + tests := []struct { + name string + transportFunc func() http.RoundTripper + configFunc func() *tls.Config + requestError bool + }{ + { + name: "Client unknown authority", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper() + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithLogger(logger), + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + )) + }, + requestError: true, + }, + { + name: "Client insecure", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSSkipVerify(true)) + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + )) + }, + }, + { + name: "Client trusted CA", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper(tlsclient.WithRootCA(bundle.CAX509Cert)) + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + )) + }, + }, + { + name: "Client without required certificate", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper(tlsclient.WithRootCA(bundle.CAX509Cert)) + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + WithClientAuthFile(bundle.CACert.Name()), + )) + }, + requestError: true, + }, + { + name: "Client verification success", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper(tlsclient.WithRootCA(bundle.CAX509Cert), tlsclient.WithClientCertificate(bundle.ClientTLSCert)) + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + WithClientAuthFile(bundle.CACert.Name()), + )) + }, + }, + { + name: "Client verification success - empty CRL", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper(tlsclient.WithRootCA(bundle.CAX509Cert), tlsclient.WithClientCertificate(bundle.ClientTLSCert)) + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + WithClientAuthFile(bundle.CACert.Name()), + WithClientCRLFile(bundle.CAEmptyCRL.Name()), + )) + }, + }, + { + name: "Client certificate revoked", + transportFunc: func() http.RoundTripper { + return tlsclient.NewDefaultRoundTripper(tlsclient.WithRootCA(bundle.CAX509Cert), tlsclient.WithClientCertificate(bundle.ClientTLSCert)) + }, + configFunc: func() *tls.Config { + return servertls.MustNewServerConfig(logger, MustNew( + WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()), + WithClientAuthFile(bundle.CACert.Name()), + WithClientCRLFile(bundle.ClientCRL.Name()), + )) + }, + requestError: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // given + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + ts.TLS = tc.configFunc() + ts.StartTLS() + + httpClient := &http.Client{ + Transport: tc.transportFunc(), + } + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + require.NoError(t, err) + + // when + res, err := httpClient.Do(req) + + // then + if tc.requestError { + t.Log(err) + require.NotNil(t, err) + return + } + require.NoError(t, err) + + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + + _ = res.Body.Close() + require.NoError(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) + + }) + } +} + +func TestCertRotation(t *testing.T) { + bundle1 := testutil.NewCertsBundle() + defer bundle1.Close() + + bundle2 := testutil.NewCertsBundle() + defer bundle2.Close() + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + rotatedCh := make(chan struct{}, 1) + notifyFunc := func() { + rotatedCh <- struct{}{} + } + source := MustNew( + WithX509KeyPair(bundle1.ServerCert.Name(), bundle1.ServerKey.Name()), + WithClientAuthFile(bundle1.CACert.Name()), + WithClientCRLFile(bundle1.CAEmptyCRL.Name()), + WithRefresh(1*time.Second), + WithNotifyFunc(notifyFunc), + ).(*fileSource) + + ts.TLS = servertls.MustNewServerConfig(slog.Default(), source) + ts.StartTLS() + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + require.NoError(t, err) + + // when + _, err = bundle1.NewHttpClient().Do(req) + require.NoError(t, err) + + // copy new certificates to be used by server + require.NoError(t, os.Rename(bundle2.ServerCert.Name(), bundle1.ServerCert.Name())) + require.NoError(t, os.Rename(bundle2.ServerKey.Name(), bundle1.ServerKey.Name())) + require.NoError(t, os.Rename(bundle2.CACert.Name(), bundle1.CACert.Name())) + require.NoError(t, os.Rename(bundle2.CAEmptyCRL.Name(), bundle1.CAEmptyCRL.Name())) + + select { + case <-rotatedCh: + t.Log("certificates were changed") + time.Sleep(100 * time.Millisecond) + case <-time.After(3 * time.Second): + t.Fatal("expected certificate change notification") + } + // old client - bad certificate + _, err = bundle1.NewHttpClient().Do(req) + require.NotNil(t, err) + var unknownAuthorityError x509.UnknownAuthorityError + require.ErrorAs(t, err.(*url.Error).Err, &unknownAuthorityError) + + // new client - success + _, err = bundle2.NewHttpClient().Do(req) + require.NoError(t, err) +} diff --git a/tls/server/filesource/option.go b/tls/server/filesource/option.go new file mode 100644 index 0000000..f1bc4c1 --- /dev/null +++ b/tls/server/filesource/option.go @@ -0,0 +1,45 @@ +package filesource + +import ( + "log/slog" + "time" +) + +type Option func(*fileSource) + +func WithLogger(logger *slog.Logger) Option { + return func(c *fileSource) { + c.logger = logger + } +} + +func WithX509KeyPair(certFile, keyFile string) Option { + return func(c *fileSource) { + c.certFile = certFile + c.keyFile = keyFile + } +} + +func WithClientAuthFile(clientAuthFile string) Option { + return func(c *fileSource) { + c.clientAuthFile = clientAuthFile + } +} + +func WithClientCRLFile(clientCRLFile string) Option { + return func(c *fileSource) { + c.clientCRLFile = clientCRLFile + } +} + +func WithRefresh(refresh time.Duration) Option { + return func(c *fileSource) { + c.refresh = refresh + } +} + +func WithNotifyFunc(notifyFunc func()) Option { + return func(c *fileSource) { + c.notifyFunc = notifyFunc + } +} diff --git a/tls/server/server.go b/tls/server/server.go new file mode 100644 index 0000000..a4dc91f --- /dev/null +++ b/tls/server/server.go @@ -0,0 +1,100 @@ +package tlsserver + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/grepplabs/cert-source/tls/certutil" + "github.com/grepplabs/cert-source/tls/server/source" +) + +const ( + initLoadTimeout = 5 * time.Second +) + +// MustNewServerConfig is like NewServerConfig but panics if the config cannot be created. +func MustNewServerConfig(logger *slog.Logger, src source.ServerCertsSource) *tls.Config { + c, err := NewServerConfig(logger, src) + if err != nil { + panic(`tls: NewServerConfig(): ` + err.Error()) + } + return c +} + +// NewServerConfig provides new server TLS configuration. +func NewServerConfig(logger *slog.Logger, src source.ServerCertsSource) (*tls.Config, error) { + store, err := NewServerCertsStore(logger, src) + if err != nil { + return nil, err + } + tlsConfig := tls.Config{ + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + cs := store.LoadServerCerts() + x := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: cs.Certificates, + } + if cs.ClientCAs != nil { + x.ClientCAs = cs.ClientCAs + x.ClientAuth = tls.RequireAndVerifyClientCert + x.VerifyPeerCertificate = verifyClientCertificate(logger, store) + } + return x, nil + }, + } + // ignored as GetConfigForClient is used. it is only required to invoke http.ListenAndServeTLS("", "") + cs := store.LoadServerCerts() + tlsConfig.Certificates = cs.Certificates + if cs.ClientCAs != nil { + tlsConfig.ClientCAs = cs.ClientCAs + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.VerifyPeerCertificate = verifyClientCertificate(logger, store) + } + return &tlsConfig, nil +} + +func NewServerCertsStore(logger *slog.Logger, src source.ServerCertsSource) (*source.ServerCertsStore, error) { + store := source.NewServerCertsStore(logger) + logger.Info("initial server certs loading") + + certsChan := src.ServerCerts() + + select { + case certs := <-certsChan: + store.SetServerCerts(certs) + case <-time.After(initLoadTimeout): + return nil, errors.New("get server certs timeout") + } + + go func() { + for certs := range certsChan { + store.SetServerCerts(certs) + } + }() + return store, nil +} + +func verifyClientCertificate(logger *slog.Logger, store *source.ServerCertsStore) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + cs := store.LoadServerCerts() + if len(cs.ClientCRLs) == 0 { + return nil + } + for _, chain := range verifiedChains { + for _, cert := range chain { + if !cert.IsCA { + if cs.IsClientCertRevoked(cert.SerialNumber) { + err := fmt.Errorf("client certificte %s was revoked", certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":")) + logger.Debug(err.Error()) + return err + } + } + } + } + return nil + } +} diff --git a/tls/server/source/pems.go b/tls/server/source/pems.go new file mode 100644 index 0000000..38777d0 --- /dev/null +++ b/tls/server/source/pems.go @@ -0,0 +1,86 @@ +package source + +import ( + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "errors" + + "github.com/grepplabs/cert-source/tls/certutil" +) + +type ServerPEMsLoader interface { + Load() (*ServerPEMs, error) +} + +type ServerPEMs struct { + CertPEMBlock []byte + KeyPEMBlock []byte + ClientAuthPEMBlock []byte + CRLPEMBlock []byte +} + +func (s ServerPEMs) Checksum() []byte { + hash := sha256.New() + hash.Write(s.CertPEMBlock) + hash.Write(s.KeyPEMBlock) + hash.Write(s.ClientAuthPEMBlock) + return hash.Sum(s.CRLPEMBlock) +} + +func (s ServerPEMs) Certificates() ([]tls.Certificate, error) { + cert, err := tls.X509KeyPair(s.CertPEMBlock, s.KeyPEMBlock) + if err != nil { + return nil, err + } + return []tls.Certificate{cert}, nil +} + +func (s ServerPEMs) ClientCAs() (*x509.CertPool, error) { + if len(s.ClientAuthPEMBlock) == 0 { + return nil, nil + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(s.ClientAuthPEMBlock) { + return nil, errors.New("server PEMs: building client CAs failed") + } + return certPool, nil +} + +func (s ServerPEMs) ClientCRLs() ([]*x509.RevocationList, error) { + if len(s.CRLPEMBlock) == 0 { + return nil, nil + } + return certutil.ParseCRLsPEM(s.CRLPEMBlock) +} + +func (s ServerPEMs) ValidateCRLs() error { + if len(s.ClientAuthPEMBlock) == 0 { + return nil + } + clientCRLs, err := s.ClientCRLs() + if err != nil { + return err + } + if len(clientCRLs) == 0 { + return nil + } + certs, err := certutil.ParseCertsPEM(s.ClientAuthPEMBlock) + if err != nil { + return err + } + for _, clientCRL := range clientCRLs { + ok := false + for _, cert := range certs { + err := clientCRL.CheckSignatureFrom(cert) + if err == nil { + ok = true + continue + } + } + if !ok { + return errors.New("server PEMs: CLR validation failure") + } + } + return nil +} diff --git a/tls/server/source/store.go b/tls/server/source/store.go new file mode 100644 index 0000000..1b2e9bd --- /dev/null +++ b/tls/server/source/store.go @@ -0,0 +1,86 @@ +package source + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "log/slog" + "math/big" + "strings" + "sync/atomic" + + "github.com/grepplabs/cert-source/tls/certutil" +) + +type ServerCertsSource interface { + ServerCerts() chan ServerCerts +} + +type ServerCerts struct { + Certificates []tls.Certificate + ClientCAs *x509.CertPool + ClientCRLs []*x509.RevocationList + Checksum []byte + RevokedSerialNumbers map[string]struct{} +} + +func (s *ServerCerts) GetChecksum() []byte { + return s.Checksum +} + +func NewRevokedSerialNumbers(clientCRLs []*x509.RevocationList) map[string]struct{} { + revokedSerialNumbers := make(map[string]struct{}) + for _, clientCRL := range clientCRLs { + for _, revoked := range clientCRL.RevokedCertificateEntries { + revokedSerialNumbers[string(revoked.SerialNumber.Bytes())] = struct{}{} + } + } + return revokedSerialNumbers +} + +func (s *ServerCerts) IsClientCertRevoked(serialNumber *big.Int) bool { + _, ok := s.RevokedSerialNumbers[string(serialNumber.Bytes())] + return ok +} + +type ServerCertsStore struct { + cs atomic.Pointer[ServerCerts] + logger *slog.Logger +} + +func NewServerCertsStore(logger *slog.Logger) *ServerCertsStore { + s := &ServerCertsStore{ + logger: logger, + } + s.cs.Store(&ServerCerts{}) + return s +} + +func (s *ServerCertsStore) LoadServerCerts() ServerCerts { + return *s.cs.Load() +} + +func (s *ServerCertsStore) SetServerCerts(certs ServerCerts) { + s.cs.Store(&certs) + s.logger.Info(fmt.Sprintf("stored x509 server certs for names [%s]", names(certs.Certificates))) +} + +func names(certs []tls.Certificate) []string { + var result []string + for _, c := range certs { + x509Cert, err := x509.ParseCertificate(c.Certificate[0]) + if err != nil { + continue + } + var names []string + if len(x509Cert.Subject.CommonName) > 0 { + names = append(names, x509Cert.Subject.CommonName) + } + names = append(names, x509Cert.DNSNames...) + for _, ip := range x509Cert.IPAddresses { + names = append(names, ip.String()) + } + result = append(result, fmt.Sprintf("%s=%s", certutil.GetHexFormatted(x509Cert.SerialNumber.Bytes(), ":"), strings.Join(names, ","))) + } + return result +} diff --git a/tls/watcher/watch.go b/tls/watcher/watch.go new file mode 100644 index 0000000..b1e5250 --- /dev/null +++ b/tls/watcher/watch.go @@ -0,0 +1,53 @@ +package watcher + +import ( + "fmt" + "log/slog" + "reflect" + "time" +) + +func Watch[T any, PT interface { + GetChecksum() []byte + *T +}](logger *slog.Logger, ch chan T, refresh time.Duration, init PT, loadFn func() (PT, error), changedFn func()) { + once := refresh <= 0 + + if refresh < time.Second { + refresh = time.Second + } + logger.Info(fmt.Sprintf("cert watch is started, refresh interval %s", refresh)) + + var last = init + for { + next, err := loadFn() + if err != nil { + logger.Error("cannot load certificates", slog.String("error", err.Error())) + time.Sleep(refresh) + continue + } + if last != nil { + if reflect.DeepEqual(next.GetChecksum(), last.GetChecksum()) { + if once && init != nil { + // init value is set, so assume it was already sent to channel + logger.Info("cert watch is disabled") + return + } + time.Sleep(refresh) + continue + } + } + + ch <- *next + last = next + + if changedFn != nil { + changedFn() + } + if once { + logger.Info("cert watch is disabled") + return + } + time.Sleep(refresh) + } +}