From 1c253f6c196c305d03cc544d1f05e199cc854626 Mon Sep 17 00:00:00 2001 From: Chaitanya Kulkarni Date: Mon, 18 Sep 2023 14:22:35 -0700 Subject: [PATCH] Update system store (#278) --- google_guest_agent/agentcrypto/mtls_mds.go | 2 +- .../agentcrypto/mtls_mds_linux.go | 83 ++++++++++++++++++- .../agentcrypto/mtls_mds_linux_test.go | 37 ++++++++- .../agentcrypto/mtls_mds_windows.go | 3 +- utils/main.go | 10 +++ utils/main_test.go | 31 +++++++ 6 files changed, 160 insertions(+), 6 deletions(-) diff --git a/google_guest_agent/agentcrypto/mtls_mds.go b/google_guest_agent/agentcrypto/mtls_mds.go index c56dc36d8e..cae1be1d03 100644 --- a/google_guest_agent/agentcrypto/mtls_mds.go +++ b/google_guest_agent/agentcrypto/mtls_mds.go @@ -167,7 +167,7 @@ func (j *CredsJob) Run(ctx context.Context) (bool, error) { return true, fmt.Errorf("failed to read Root CA cert with an error: %w", err) } - if err := j.writeRootCACert(v.Content, filepath.Join(defaultCredsDir, rootCACertFileName)); err != nil { + if err := j.writeRootCACert(ctx, v.Content, filepath.Join(defaultCredsDir, rootCACertFileName)); err != nil { return true, fmt.Errorf("failed to store Root CA cert with an error: %w", err) } diff --git a/google_guest_agent/agentcrypto/mtls_mds_linux.go b/google_guest_agent/agentcrypto/mtls_mds_linux.go index 45fd6cd06e..e2b62360a5 100644 --- a/google_guest_agent/agentcrypto/mtls_mds_linux.go +++ b/google_guest_agent/agentcrypto/mtls_mds_linux.go @@ -14,7 +14,16 @@ package agentcrypto -import "github.com/GoogleCloudPlatform/guest-agent/utils" +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + + "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run" + "github.com/GoogleCloudPlatform/guest-agent/utils" + "github.com/GoogleCloudPlatform/guest-logging-go/logger" +) const ( // defaultCredsDir is the directory location for MTLS MDS credentials. @@ -27,11 +36,79 @@ const ( ) // writeRootCACert writes Root CA cert from UEFI variable to output file. -func (j *CredsJob) writeRootCACert(content []byte, outputFile string) error { - return utils.SaferWriteFile(content, outputFile) +func (j *CredsJob) writeRootCACert(ctx context.Context, content []byte, outputFile string) error { + if err := utils.SaferWriteFile(content, outputFile); err != nil { + return err + } + + // Best effort to update system store, don't fail. + if err := updateSystemStore(ctx, outputFile); err != nil { + logger.Errorf("Failed add Root MDS cert to system trust store with error: %v", err) + } + + return nil } // writeClientCredentials stores client credentials (certificate and private key). func (j *CredsJob) writeClientCredentials(plaintext []byte, outputFile string) error { return utils.SaferWriteFile(plaintext, outputFile) } + +// getCAStoreUpdater interates over known system trust store updaters and returns the first found. +func getCAStoreUpdater() (string, error) { + knownUpdaters := []string{"update-ca-certificates", "update-ca-trust"} + var errs []string + + for _, u := range knownUpdaters { + _, err := exec.LookPath(u) + if err == nil { + return u, nil + } + errs = append(errs, err.Error()) + } + + return "", fmt.Errorf("no known trust updaters %v were found: %v", knownUpdaters, errs) +} + +// certificateDirFromUpdater returns directory of local CA certificates for the given updater tool. +func certificateDirFromUpdater(updater string) (string, error) { + switch updater { + // SUSE, Debian and Ubuntu distributions. + // https://manpages.ubuntu.com/manpages/xenial/man8/update-ca-certificates.8.html + case "update-ca-certificates": + return "/usr/local/share/ca-certificates/", nil + // CentOS, Fedora, RedHat distributions. + // https://www.unix.com/man-page/centos/8/UPDATE-CA-TRUST/ + case "update-ca-trust": + return "/etc/pki/ca-trust/source/anchors/", nil + default: + return "", fmt.Errorf("unknown updater %q, no local trusted CA certificate directory found", updater) + } +} + +// updateSystemStore updates the local system store with the cert. +func updateSystemStore(ctx context.Context, cert string) error { + cmd, err := getCAStoreUpdater() + if err != nil { + return err + } + + dir, err := certificateDirFromUpdater(cmd) + if err != nil { + return err + } + + dest := filepath.Join(dir, filepath.Base(cert)) + + if err := utils.CopyFile(cert, dest); err != nil { + return err + } + + res := run.WithOutput(ctx, cmd) + if res.ExitCode != 0 { + return fmt.Errorf("command %q failed with error: %s", cmd, res.Error()) + } + + logger.Infof("Certificate %q added to system store successfully %s", cert, res.StdOut) + return nil +} diff --git a/google_guest_agent/agentcrypto/mtls_mds_linux_test.go b/google_guest_agent/agentcrypto/mtls_mds_linux_test.go index a158eece3d..1a03b722f0 100644 --- a/google_guest_agent/agentcrypto/mtls_mds_linux_test.go +++ b/google_guest_agent/agentcrypto/mtls_mds_linux_test.go @@ -44,7 +44,7 @@ func TestReadAndWriteRootCACert(t *testing.T) { t.Errorf("readRootCACert(%+v) failed unexpectedly with error: %v", v, err) } - if err := j.writeRootCACert(ca.Content, crt); err != nil { + if err := j.writeRootCACert(context.Background(), ca.Content, crt); err != nil { t.Errorf("writeRootCACert(%s, %s) failed unexpectedly with error: %v", string(ca.Content), crt, err) } @@ -130,3 +130,38 @@ func TestShouldEnableError(t *testing.T) { t.Error("ShouldEnable(ctx) = true, want false") } } + +func TestCertificateDirFromUpdater(t *testing.T) { + tests := []struct { + updater string + want string + }{ + { + updater: "update-ca-certificates", + want: "/usr/local/share/ca-certificates/", + }, + { + updater: "update-ca-trust", + want: "/etc/pki/ca-trust/source/anchors/", + }, + } + + for _, test := range tests { + t.Run(test.updater, func(t *testing.T) { + got, err := certificateDirFromUpdater(test.updater) + if err != nil { + t.Errorf("certificateDirFromUpdater(%s) failed unexpectedly with error: %v", test.updater, err) + } + if got != test.want { + t.Errorf("certificateDirFromUpdater(%s) = %s, want %s", test.updater, got, test.want) + } + }) + } +} + +func TestCertificateDirFromUpdaterError(t *testing.T) { + _, err := certificateDirFromUpdater("unknown") + if err == nil { + t.Errorf("certificateDirFromUpdater(unknown) succeeded for unknown updater, want error") + } +} diff --git a/google_guest_agent/agentcrypto/mtls_mds_windows.go b/google_guest_agent/agentcrypto/mtls_mds_windows.go index f2fc3a3f13..8201479a28 100644 --- a/google_guest_agent/agentcrypto/mtls_mds_windows.go +++ b/google_guest_agent/agentcrypto/mtls_mds_windows.go @@ -15,6 +15,7 @@ package agentcrypto import ( + "context" "crypto/rand" "crypto/x509" "encoding/pem" @@ -57,7 +58,7 @@ var ( ) // writeRootCACert writes Root CA cert from UEFI variable to output file. -func (j *CredsJob) writeRootCACert(cacert []byte, outputFile string) error { +func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile string) error { if err := utils.SaferWriteFile(cacert, outputFile); err != nil { return err } diff --git a/utils/main.go b/utils/main.go index 35608da79b..9e63444c73 100644 --- a/utils/main.go +++ b/utils/main.go @@ -197,3 +197,13 @@ func SaferWriteFile(content []byte, outputFile string) error { return os.Rename(tmp.Name(), outputFile) } + +// CopyFile copies content from src to dst. +func CopyFile(src, dst string) error { + b, err := os.ReadFile(src) + if err != nil { + return fmt.Errorf("failed to read %q: %w", src, err) + } + + return WriteFile(b, dst) +} diff --git a/utils/main_test.go b/utils/main_test.go index 270b375e18..db4fce3efc 100644 --- a/utils/main_test.go +++ b/utils/main_test.go @@ -149,3 +149,34 @@ func TestSaferWriteFile(t *testing.T) { t.Errorf("os.ReadFile(%s) = %s, want %s", f, string(got), want) } } + +func TestCopyFile(t *testing.T) { + tmp := t.TempDir() + dst := filepath.Join(tmp, "dst") + src := filepath.Join(tmp, "src") + want := "testdata" + if err := os.WriteFile(src, []byte(want), 0777); err != nil { + t.Fatalf("failed to write test source file: %v", err) + } + if err := CopyFile(src, dst); err != nil { + t.Errorf("CopyFile(%s, %s) failed unexpectedly with error: %v", src, dst, err) + } + + got, err := os.ReadFile(dst) + if err != nil { + t.Errorf("unable to read %q: %v", dst, err) + } + if string(got) != want { + t.Errorf("CopyFile(%s, %s) copied %q, expected %q", src, dst, string(got), want) + } +} + +func TestCopyFileError(t *testing.T) { + tmp := t.TempDir() + dst := filepath.Join(tmp, "dst") + src := filepath.Join(tmp, "src") + + if err := CopyFile(src, dst); err == nil { + t.Errorf("CopyFile(%s, %s) succeeded for non-existent file, want error", src, dst) + } +}