Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support auto tls #115

Merged
merged 9 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ package main

import (
"context"
"crypto/ecdsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -60,8 +63,9 @@ import (
var version = "0.0.0-dev"

const (
slashSeparator = "/"
healthPath = "/v1/health"
slashSeparator = "/"
healthPath = "/v1/health"
certificatesPath = "/v1/certificates"
)

var (
Expand All @@ -76,6 +80,7 @@ var (
globalConnStats atomic.Pointer[[]*ConnStats]
log2 *logrus.Logger
globalHostBalance string
globalTLSCert atomic.Pointer[[]byte]
)

const (
Expand Down Expand Up @@ -516,15 +521,26 @@ func (m *multisite) populate() {
}

func (m *multisite) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && r.URL.Path == certificatesPath {
cert := globalTLSCert.Load()
if cert != nil {
w.Write(*cert)
} else {
http.Error(w, "no configured certificates found", http.StatusNotFound)
}
return
}
w.Header().Set("Server", "SideKick") // indicate sidekick is serving
for _, s := range *m.sites.Load() {
if s.Online() {
if r.URL.Path == healthPath {
switch r.URL.Path {
case healthPath:
// Health check endpoint should return success
return
default:
s.ServeHTTP(w, r)
return
}
s.ServeHTTP(w, r)
return
}
}
writeErrorResponse(w, r, errors.New("all backend servers are offline"))
Expand Down Expand Up @@ -1086,6 +1102,43 @@ func sidekickMain(ctx *cli.Context) {
ClientSessionCache: tls.NewLRUClientSessionCache(tlsClientSessionCacheSize),
}
server.TLSConfig = tlsConfig
} else if ctx.String("auto-tls-host") != "" {
cert, key, err := generateTLSCertKey(ctx.String("auto-tls-host"))
if err != nil {
console.Fatalln(err)
}
console.Printf("Generated TLS certificate for host '%s'\n", ctx.String("auto-tls-host"))
certificates, err := tls.X509KeyPair(cert, key)
if err != nil {
console.Fatalln(err)
}
fingerprint := sha256.Sum256(certificates.Certificate[0])
console.Printf("\nCertificate: % X", fingerprint[:len(fingerprint)/2])
console.Printf("\n % X", fingerprint[len(fingerprint)/2:])
var publicKeyDER []byte
switch privateKey := certificates.PrivateKey.(type) {
case *ecdsa.PrivateKey:
publicKeyDER, err = x509.MarshalPKIXPublicKey(privateKey.Public())
default:
console.Fatalln(fmt.Errorf("unsupported private key type %T", privateKey))
}
if err != nil {
console.Fatalln(err)
}
publicKey := sha256.Sum256(publicKeyDER)
console.Println("\nPublic Key: " + base64.StdEncoding.EncodeToString(publicKey[:]))
console.Println()
globalTLSCert.Store(&cert)

tlsConfig := &tls.Config{
PreferServerCipherSuites: true,
NextProtos: []string{"http/1.1", "h2"},
Certificates: []tls.Certificate{certificates},
MinVersion: tls.VersionTLS12,
MaxVersion: tlsMaxVersion,
ClientSessionCache: tls.NewLRUClientSessionCache(tlsClientSessionCacheSize),
}
server.TLSConfig = tlsConfig
}
go func() {
if err := server.ListenAndServe(); err != nil {
Expand Down Expand Up @@ -1163,6 +1216,10 @@ func main() {
Name: "rr-dns-mode",
Usage: "enable round-robin DNS mode",
},
cli.StringFlag{
Name: "auto-tls-host",
Usage: "enable auto TLS mode for the specified host",
},
cli.BoolFlag{
Name: "log, l",
Usage: "enable logging",
Expand Down
124 changes: 124 additions & 0 deletions tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) 2021-2024 MinIO, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package main

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
crand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"strings"
"time"
)

// generateTLSCertKey creates valid key/cert with registered DNS or IP address
// depending on the passed parameter. That way, we can use tls config without
// passing InsecureSkipVerify flag. This code is a simplified version of
// https://golang.org/src/crypto/tls/generate_cert.go
func generateTLSCertKey(host string) ([]byte, []byte, error) {
validFor := 365 * 24 * time.Hour
if len(host) == 0 {
return nil, nil, fmt.Errorf("Missing host parameter")
}

publicKey := func(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return nil
}
}

jiuker marked this conversation as resolved.
Show resolved Hide resolved
pemBlockForKey := func(priv interface{}) *pem.Block {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
case *ecdsa.PrivateKey:
b, err := x509.MarshalECPrivateKey(k)
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err)
os.Exit(2)
}
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
default:
return nil
}
}

var priv interface{}
var err error
priv, err = ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}
notBefore := time.Now()
notAfter := notBefore.Add(validFor)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := crand.Int(crand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,

KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

hosts := strings.Split(host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}

template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign

derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
return nil, nil, fmt.Errorf("Failed to create certificate: %w", err)
}

certOut := bytes.NewBuffer([]byte{})
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})

keyOut := bytes.NewBuffer([]byte{})
pem.Encode(keyOut, pemBlockForKey(priv))

return certOut.Bytes(), keyOut.Bytes(), nil
}
Loading