diff --git a/certificate.go b/certificate.go index d1cb792..4dd90ed 100644 --- a/certificate.go +++ b/certificate.go @@ -104,6 +104,7 @@ const ( ) // TLSCertificate returns the Certificate as tls.Certificate. +// Complete certificate chain (up to but not including root) is included for end-entity certificates. // A key pair and certificate will be generated at first call of any Certificate functions. // Error is not nil if generation fails. func (c *Certificate) TLSCertificate() (tls.Certificate, error) { @@ -149,6 +150,7 @@ func (c *Certificate) PrivateKey() (crypto.Signer, error) { } // PEM returns the Certificate as certificate and private key PEM buffers. +// Complete certificate chain (up to but not including root) is included for end-entity certificates. // A key pair and certificate will be generated at first call of any Certificate functions. // Error is not nil if generation fails. func (c *Certificate) PEM() (cert []byte, key []byte, err error) { @@ -157,38 +159,28 @@ func (c *Certificate) PEM() (cert []byte, key []byte, err error) { return } - var buf bytes.Buffer - - err = pem.Encode(&buf, &pem.Block{ - Type: "CERTIFICATE", - Bytes: c.GeneratedCert.Certificate[0], - }) + cert, err = encodeToPEMBlocks("CERTIFICATE", c.GeneratedCert.Certificate) if err != nil { return } - cert = append(cert, buf.Bytes()...) // Create copy of underlying buf. - - buf.Reset() k, err := x509.MarshalPKCS8PrivateKey(c.GeneratedCert.PrivateKey) if err != nil { cert = nil return } - err = pem.Encode(&buf, &pem.Block{ - Type: "PRIVATE KEY", - Bytes: k, - }) + + key, err = encodeToPEMBlocks("PRIVATE KEY", [][]byte{k}) if err != nil { cert = nil return } - key = append(key, buf.Bytes()...) // Create copy of underlying buf. return } // WritePEM writes the Certificate as certificate and private key PEM files. +// Complete certificate chain (up to but not including root) is included for end-entity certificates. // A key pair and certificate will be generated at first call of any Certificate functions. // Error is not nil if generation fails. func (c *Certificate) WritePEM(certFile, keyFile string) error { @@ -357,6 +349,7 @@ func (c *Certificate) Generate() error { var issuerCert *x509.Certificate var issuerKey crypto.Signer + var chain [][]byte if c.Issuer != nil { issuerCert, err = x509.ParseCertificate(c.Issuer.GeneratedCert.Certificate[0]) if err != nil { @@ -364,6 +357,17 @@ func (c *Certificate) Generate() error { } issuerKey = c.Issuer.GeneratedCert.PrivateKey.(crypto.Signer) + // Add certificate chain to end-entity certificates. + if !*c.IsCA { + issuer := c.Issuer + for issuer != nil { + // Add issuer to chain unless it is root certificate. + if issuer.Issuer != nil { + chain = append(chain, issuer.GeneratedCert.Certificate[0]) + } + issuer = issuer.Issuer + } + } } else { // create self-signed certificate issuerCert = template @@ -377,9 +381,26 @@ func (c *Certificate) Generate() error { } c.GeneratedCert = &tls.Certificate{ - Certificate: [][]byte{cert}, + Certificate: append([][]byte{cert}, chain...), PrivateKey: key, } return nil } + +func encodeToPEMBlocks(blockType string, blocks [][]byte) ([]byte, error) { + var buf bytes.Buffer + + for _, b := range blocks { + err := pem.Encode(&buf, &pem.Block{ + Type: blockType, + Bytes: b, + }) + if err != nil { + return nil, err + } + + } + + return buf.Bytes(), nil +} diff --git a/certificate_test.go b/certificate_test.go index af544c8..a937b74 100644 --- a/certificate_test.go +++ b/certificate_test.go @@ -349,3 +349,56 @@ func TestSerial(t *testing.T) { assert.Nil(t, err) assert.NotEqual(t, got1.SerialNumber, got2.SerialNumber) } + +func TestCertificateChain(t *testing.T) { + isCA := true + rootCA := Certificate{Subject: "CN=ca"} + subCA1 := Certificate{Subject: "CN=sub-ca-1", Issuer: &rootCA, IsCA: &isCA} + subCA2 := Certificate{Subject: "CN=sub-ca-2", Issuer: &subCA1, IsCA: &isCA} + endEntity := Certificate{Subject: "CN=end-entity", Issuer: &subCA2} + + // End-entity certificates have certificate chains appended. + got, err := endEntity.TLSCertificate() + assert.Nil(t, err) + assert.Equal(t, 3, len(got.Certificate)) + assert.Equal(t, endEntity.GeneratedCert.Certificate[0], got.Certificate[0]) + assert.Equal(t, subCA2.GeneratedCert.Certificate[0], got.Certificate[1]) + assert.Equal(t, subCA1.GeneratedCert.Certificate[0], got.Certificate[2]) + + // CA certificates do not have chains appended. + got, err = subCA2.TLSCertificate() + assert.Nil(t, err) + assert.Equal(t, 1, len(got.Certificate)) +} + +func TestCertificateChainInPEM(t *testing.T) { + isCA := true + rootCA := Certificate{Subject: "CN=ca"} + subCA1 := Certificate{Subject: "CN=sub-ca-1", Issuer: &rootCA, IsCA: &isCA} + subCA2 := Certificate{Subject: "CN=sub-ca-2", Issuer: &subCA1, IsCA: &isCA} + endEntity := Certificate{Subject: "CN=end-entity", Issuer: &subCA2} + + // End-entity certificates have certificate chains appended. + got, _, err := endEntity.PEM() + assert.Nil(t, err) + + block, rest := pem.Decode(got) + assert.NotNil(t, block) + cert, err := x509.ParseCertificate(block.Bytes) + assert.Nil(t, err) + assert.Equal(t, "CN=end-entity", cert.Subject.String()) + + block, rest = pem.Decode(rest) + assert.NotNil(t, block) + cert, err = x509.ParseCertificate(block.Bytes) + assert.Nil(t, err) + assert.Equal(t, "CN=sub-ca-2", cert.Subject.String()) + + block, rest = pem.Decode(rest) + assert.NotNil(t, block) + cert, err = x509.ParseCertificate(block.Bytes) + assert.Nil(t, err) + assert.Equal(t, "CN=sub-ca-1", cert.Subject.String()) + + assert.Empty(t, rest) +}