diff --git a/cli/upgrade.go b/cli/upgrade.go index e38b8725..2f0a4215 100644 --- a/cli/upgrade.go +++ b/cli/upgrade.go @@ -60,16 +60,43 @@ func windowsDownload(keyConjurerRcPath string) error { // defaultDownload replaces the currently executing binary by writing over it directly. func defaultDownload(ctx context.Context, client *http.Client, keyConjurerRcPath string) error { - f, err := os.OpenFile(keyConjurerRcPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0744) + tmp, err := os.CreateTemp(os.TempDir(), "keyconjurer") if err != nil { - return fmt.Errorf("unable to open %q: %w", keyConjurerRcPath, err) + return fmt.Errorf("failed to create temporary file for upgrade: %w", err) } - defer f.Close() - if err := DownloadLatestBinary(ctx, client, f); err != nil { + defer tmp.Close() + src, err := DownloadLatestBinary(ctx, client, tmp) + if err != nil { return fmt.Errorf("unable to download the latest binary: %w", err) } + if err := tmp.Close(); err != nil { + return fmt.Errorf("could not close tmp file: %w", err) + } + + bytesCopied, err := io.Copy(tmp, src) + if err != nil { + return fmt.Errorf("failed to copy new keyconjurer: %s", err) + } + + // Re-open the temporary file for reading and copy: + r, err := os.Open(tmp.Name()) + if err != nil { + return fmt.Errorf("could not open temporary file %s: %w", tmp.Name(), err) + } + + kc, _ := os.OpenFile(keyConjurerRcPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0744) + if err != nil { + return fmt.Errorf("unable to open %q: %w", keyConjurerRcPath, err) + } + + bytesCopied2, err := io.Copy(kc, r) + if err != nil || bytesCopied != bytesCopied2 { + // If an error occurs here, KeyConjurer has been overwritten and is potentially corrrupted + return fmt.Errorf("failed to copy new keyconjurer contents - keyconjurer is potentially corrupted and may need to be downloaded again: %w", err) + } + return nil } @@ -92,22 +119,21 @@ func getBinaryName() string { } // DownloadLatestBinary downloads the latest keyconjurer binary from the web. -func DownloadLatestBinary(ctx context.Context, client *http.Client, w io.Writer) error { +func DownloadLatestBinary(ctx context.Context, client *http.Client, w io.Writer) (io.ReadCloser, error) { binaryURL := fmt.Sprintf("%s/%s", DownloadURL, getBinaryName()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, binaryURL, nil) if err != nil { - return fmt.Errorf("could not upgrade: %w", err) + return nil, fmt.Errorf("could not upgrade: %w", err) } res, err := client.Do(req) if err != nil { - return fmt.Errorf("could not upgrade: %w", err) + return nil, fmt.Errorf("could not upgrade: %w", err) } if res.StatusCode != 200 { - return errors.New("could not upgrade: response did not indicate success - are you being blocked by the server?") + return nil, errors.New("could not upgrade: response did not indicate success - are you being blocked by the server?") } - _, err = io.Copy(w, res.Body) - return err + return req.Body, nil }