Skip to content

Commit

Permalink
Update system store (GoogleCloudPlatform#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaitanyaKulkarni28 committed Sep 18, 2023
1 parent 098c008 commit 1c253f6
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 6 deletions.
2 changes: 1 addition & 1 deletion google_guest_agent/agentcrypto/mtls_mds.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (j *CredsJob) Run(ctx context.Context) (bool, error) {
return true, fmt.Errorf("failed to read Root CA cert with an error: %w", err)
}

if err := j.writeRootCACert(v.Content, filepath.Join(defaultCredsDir, rootCACertFileName)); err != nil {
if err := j.writeRootCACert(ctx, v.Content, filepath.Join(defaultCredsDir, rootCACertFileName)); err != nil {
return true, fmt.Errorf("failed to store Root CA cert with an error: %w", err)
}

Expand Down
83 changes: 80 additions & 3 deletions google_guest_agent/agentcrypto/mtls_mds_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@

package agentcrypto

import "github.com/GoogleCloudPlatform/guest-agent/utils"
import (
"context"
"fmt"
"os/exec"
"path/filepath"

"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run"
"github.com/GoogleCloudPlatform/guest-agent/utils"
"github.com/GoogleCloudPlatform/guest-logging-go/logger"
)

const (
// defaultCredsDir is the directory location for MTLS MDS credentials.
Expand All @@ -27,11 +36,79 @@ const (
)

// writeRootCACert writes Root CA cert from UEFI variable to output file.
func (j *CredsJob) writeRootCACert(content []byte, outputFile string) error {
return utils.SaferWriteFile(content, outputFile)
func (j *CredsJob) writeRootCACert(ctx context.Context, content []byte, outputFile string) error {
if err := utils.SaferWriteFile(content, outputFile); err != nil {
return err
}

// Best effort to update system store, don't fail.
if err := updateSystemStore(ctx, outputFile); err != nil {
logger.Errorf("Failed add Root MDS cert to system trust store with error: %v", err)
}

return nil
}

// writeClientCredentials stores client credentials (certificate and private key).
func (j *CredsJob) writeClientCredentials(plaintext []byte, outputFile string) error {
return utils.SaferWriteFile(plaintext, outputFile)
}

// getCAStoreUpdater interates over known system trust store updaters and returns the first found.
func getCAStoreUpdater() (string, error) {
knownUpdaters := []string{"update-ca-certificates", "update-ca-trust"}
var errs []string

for _, u := range knownUpdaters {
_, err := exec.LookPath(u)
if err == nil {
return u, nil
}
errs = append(errs, err.Error())
}

return "", fmt.Errorf("no known trust updaters %v were found: %v", knownUpdaters, errs)
}

// certificateDirFromUpdater returns directory of local CA certificates for the given updater tool.
func certificateDirFromUpdater(updater string) (string, error) {
switch updater {
// SUSE, Debian and Ubuntu distributions.
// https://manpages.ubuntu.com/manpages/xenial/man8/update-ca-certificates.8.html
case "update-ca-certificates":
return "/usr/local/share/ca-certificates/", nil
// CentOS, Fedora, RedHat distributions.
// https://www.unix.com/man-page/centos/8/UPDATE-CA-TRUST/
case "update-ca-trust":
return "/etc/pki/ca-trust/source/anchors/", nil
default:
return "", fmt.Errorf("unknown updater %q, no local trusted CA certificate directory found", updater)
}
}

// updateSystemStore updates the local system store with the cert.
func updateSystemStore(ctx context.Context, cert string) error {
cmd, err := getCAStoreUpdater()
if err != nil {
return err
}

dir, err := certificateDirFromUpdater(cmd)
if err != nil {
return err
}

dest := filepath.Join(dir, filepath.Base(cert))

if err := utils.CopyFile(cert, dest); err != nil {
return err
}

res := run.WithOutput(ctx, cmd)
if res.ExitCode != 0 {
return fmt.Errorf("command %q failed with error: %s", cmd, res.Error())
}

logger.Infof("Certificate %q added to system store successfully %s", cert, res.StdOut)
return nil
}
37 changes: 36 additions & 1 deletion google_guest_agent/agentcrypto/mtls_mds_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestReadAndWriteRootCACert(t *testing.T) {
t.Errorf("readRootCACert(%+v) failed unexpectedly with error: %v", v, err)
}

if err := j.writeRootCACert(ca.Content, crt); err != nil {
if err := j.writeRootCACert(context.Background(), ca.Content, crt); err != nil {
t.Errorf("writeRootCACert(%s, %s) failed unexpectedly with error: %v", string(ca.Content), crt, err)
}

Expand Down Expand Up @@ -130,3 +130,38 @@ func TestShouldEnableError(t *testing.T) {
t.Error("ShouldEnable(ctx) = true, want false")
}
}

func TestCertificateDirFromUpdater(t *testing.T) {
tests := []struct {
updater string
want string
}{
{
updater: "update-ca-certificates",
want: "/usr/local/share/ca-certificates/",
},
{
updater: "update-ca-trust",
want: "/etc/pki/ca-trust/source/anchors/",
},
}

for _, test := range tests {
t.Run(test.updater, func(t *testing.T) {
got, err := certificateDirFromUpdater(test.updater)
if err != nil {
t.Errorf("certificateDirFromUpdater(%s) failed unexpectedly with error: %v", test.updater, err)
}
if got != test.want {
t.Errorf("certificateDirFromUpdater(%s) = %s, want %s", test.updater, got, test.want)
}
})
}
}

func TestCertificateDirFromUpdaterError(t *testing.T) {
_, err := certificateDirFromUpdater("unknown")
if err == nil {
t.Errorf("certificateDirFromUpdater(unknown) succeeded for unknown updater, want error")
}
}
3 changes: 2 additions & 1 deletion google_guest_agent/agentcrypto/mtls_mds_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package agentcrypto

import (
"context"
"crypto/rand"
"crypto/x509"
"encoding/pem"
Expand Down Expand Up @@ -57,7 +58,7 @@ var (
)

// writeRootCACert writes Root CA cert from UEFI variable to output file.
func (j *CredsJob) writeRootCACert(cacert []byte, outputFile string) error {
func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile string) error {
if err := utils.SaferWriteFile(cacert, outputFile); err != nil {
return err
}
Expand Down
10 changes: 10 additions & 0 deletions utils/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,13 @@ func SaferWriteFile(content []byte, outputFile string) error {

return os.Rename(tmp.Name(), outputFile)
}

// CopyFile copies content from src to dst.
func CopyFile(src, dst string) error {
b, err := os.ReadFile(src)
if err != nil {
return fmt.Errorf("failed to read %q: %w", src, err)
}

return WriteFile(b, dst)
}
31 changes: 31 additions & 0 deletions utils/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,34 @@ func TestSaferWriteFile(t *testing.T) {
t.Errorf("os.ReadFile(%s) = %s, want %s", f, string(got), want)
}
}

func TestCopyFile(t *testing.T) {
tmp := t.TempDir()
dst := filepath.Join(tmp, "dst")
src := filepath.Join(tmp, "src")
want := "testdata"
if err := os.WriteFile(src, []byte(want), 0777); err != nil {
t.Fatalf("failed to write test source file: %v", err)
}
if err := CopyFile(src, dst); err != nil {
t.Errorf("CopyFile(%s, %s) failed unexpectedly with error: %v", src, dst, err)
}

got, err := os.ReadFile(dst)
if err != nil {
t.Errorf("unable to read %q: %v", dst, err)
}
if string(got) != want {
t.Errorf("CopyFile(%s, %s) copied %q, expected %q", src, dst, string(got), want)
}
}

func TestCopyFileError(t *testing.T) {
tmp := t.TempDir()
dst := filepath.Join(tmp, "dst")
src := filepath.Join(tmp, "src")

if err := CopyFile(src, dst); err == nil {
t.Errorf("CopyFile(%s, %s) succeeded for non-existent file, want error", src, dst)
}
}

0 comments on commit 1c253f6

Please sign in to comment.