diff --git a/cli/cage/upgrade/upgrade.go b/cli/cage/upgrade/upgrade.go index 259e75c..81749a4 100644 --- a/cli/cage/upgrade/upgrade.go +++ b/cli/cage/upgrade/upgrade.go @@ -74,7 +74,7 @@ func (u *upgrader) Upgrade(p *Input) error { return err } log.Infof("downloading binary %s...", binaryAsset.GetName()) - cageRd, err := unzipArchive(binaryAsset.GetBrowserDownloadURL(), checksum) + newCageFile, err := unzipArchive(binaryAsset.GetBrowserDownloadURL(), checksum) if err != nil { return err } @@ -87,7 +87,7 @@ func (u *upgrader) Upgrade(p *Input) error { } targetPath = exec } - if err := swapFiles(targetPath, cageRd); err != nil { + if err := swapFiles(targetPath, newCageFile); err != nil { return err } log.Infof("upgraded to %s", version) @@ -124,39 +124,48 @@ func (u *upgrader) FindLatestRelease(pre bool) (*github.RepositoryRelease, error func unzipArchive( assetUrl string, checksum []byte, -) (io.ReadCloser, error) { +) (string, error) { resp, err := http.DefaultClient.Get(assetUrl) if err != nil { - return nil, err + return "", err } defer resp.Body.Close() zipdest, err := os.CreateTemp("", "cage") if err != nil { - return nil, err + return "", err } defer zipdest.Close() sha := sha256.New() if _, err := io.Copy(zipdest, io.TeeReader(resp.Body, sha)); err != nil { - return nil, err + return "", err } actChecksum := sha.Sum(nil) if !bytes.Equal(checksum, actChecksum) { - return nil, xerrors.Errorf("checksum mismatch: expected %x, got %x", checksum, actChecksum) + return "", xerrors.Errorf("checksum mismatch: expected %x, got %x", checksum, actChecksum) } ziprd, err := zip.OpenReader(zipdest.Name()) if err != nil { - return nil, err + return "", err } defer ziprd.Close() cageRd, err := ziprd.Open("cage") if err != nil { - return nil, err + return "", err } - return cageRd, nil + defer cageRd.Close() + cageDest, err := os.CreateTemp("", "cage") + if err != nil { + return "", err + } + defer cageDest.Close() + if _, err := io.Copy(cageDest, cageRd); err != nil { + return "", err + } + return cageDest.Name(), nil } func parseChecksums(url string, file string) ([]byte, error) { @@ -195,20 +204,12 @@ func parseChecksums(url string, file string) ([]byte, error) { func swapFiles( targetPath string, - newReader io.ReadCloser, + newPath string, ) error { - defer newReader.Close() - // Write to a new file newFilepath := targetPath + ".new" - newFile, err := os.OpenFile(newFilepath, os.O_CREATE|os.O_WRONLY, 0755) - if err != nil { - return err - } - defer newFile.Close() - if _, err := io.Copy(newFile, newReader); err != nil { + if err := os.Rename(newPath, newFilepath); err != nil { return err } - // Swap files oldFilepath := targetPath + ".old" if err := os.Rename(targetPath, oldFilepath); err != nil { return err