diff --git a/vacation/vacation.go b/vacation/vacation.go index d6f005af..4a3ec3ce 100644 --- a/vacation/vacation.go +++ b/vacation/vacation.go @@ -74,6 +74,7 @@ func (ta TarArchive) Decompress(destination string) error { // tarball, which can be seen in the test around there being no directory // metadata. directories := map[string]interface{}{} + type header struct { name string linkname string @@ -160,11 +161,14 @@ func (ta TarArchive) Decompress(destination string) error { } for _, h := range symlinkHeaders { - _, err := filepath.EvalSymlinks(filepath.Join(destination, h.linkname)) + // Check to see if the file that will be linked to is valid for symlinking + _, err := filepath.EvalSymlinks(filepath.Join(filepath.Dir(h.path), h.linkname)) if err != nil { return err } + // Check that the file being symlinked to is inside the destination + // directory err = checkExtractPath(filepath.Join(filepath.Dir(h.name), h.linkname), destination) if err != nil { return err @@ -299,6 +303,14 @@ func NewZipArchive(inputReader io.Reader) ZipArchive { // Decompress reads from ZipArchive and writes files into the destination // specified. func (z ZipArchive) Decompress(destination string) error { + type header struct { + name string + linkname string + path string + } + + var symlinkHeaders []header + // Have to convert an io.Reader into a bytes.Reader which implements the // ReadAt function making it compatible with the io.ReaderAt inteface which // required for zip.NewReader @@ -340,15 +352,14 @@ func (z ZipArchive) Decompress(destination string) error { return err } - err = checkExtractPath(filepath.Join(filepath.Dir(f.Name), string(linkname)), destination) - if err != nil { - return err - } + // Collect all of the headers for symlinks so that they can be verified + // after all other files are written + symlinkHeaders = append(symlinkHeaders, header{ + name: f.Name, + linkname: string(linkname), + path: path, + }) - err = os.Symlink(string(linkname), path) - if err != nil { - return fmt.Errorf("failed to unzip symlink: %w", err) - } default: err = os.MkdirAll(filepath.Dir(path), os.ModePerm) if err != nil { @@ -374,6 +385,26 @@ func (z ZipArchive) Decompress(destination string) error { } } + for _, h := range symlinkHeaders { + // Check to see if the file that will be linked to is valid for symlinking + _, err := filepath.EvalSymlinks(filepath.Join(filepath.Dir(h.path), h.linkname)) + if err != nil { + return err + } + + // Check that the file being symlinked to is inside the destination + // directory + err = checkExtractPath(filepath.Join(filepath.Dir(h.name), h.linkname), destination) + if err != nil { + return err + } + + err = os.Symlink(h.linkname, h.path) + if err != nil { + return fmt.Errorf("failed to unzip symlink: %w", err) + } + } + return nil } diff --git a/vacation/vacation_tar_test.go b/vacation/vacation_tar_test.go index d5a119ff..c356e3c9 100644 --- a/vacation/vacation_tar_test.go +++ b/vacation/vacation_tar_test.go @@ -229,7 +229,7 @@ func testVacationTar(t *testing.T, context spec.G, it spec.S) { }) }) - context("when it tries to symlinkto a file that does not exist", func() { + context("when it tries to symlink to a file that does not exist", func() { var zipSlipSymlinkTar vacation.TarArchive it.Before(func() { @@ -280,7 +280,7 @@ func testVacationTar(t *testing.T, context spec.G, it spec.S) { }) }) - context("when it tries to decompress a broken symlink", func() { + context("when the symlink creation fails", func() { var brokenSymlinkTar vacation.TarArchive it.Before(func() { diff --git a/vacation/vacation_zip_test.go b/vacation/vacation_zip_test.go index 876be27b..38a8994e 100644 --- a/vacation/vacation_zip_test.go +++ b/vacation/vacation_zip_test.go @@ -33,13 +33,22 @@ func testVacationZip(t *testing.T, context spec.G, it spec.S) { buffer := bytes.NewBuffer(nil) zw := zip.NewWriter(buffer) + fileHeader := &zip.FileHeader{Name: "symlink"} + fileHeader.SetMode(0755 | os.ModeSymlink) + + symlink, err := zw.CreateHeader(fileHeader) + Expect(err).NotTo(HaveOccurred()) + + _, err = symlink.Write([]byte(filepath.Join("some-dir", "some-other-dir", "some-file"))) + Expect(err).NotTo(HaveOccurred()) + _, err = zw.Create("some-dir/") Expect(err).NotTo(HaveOccurred()) _, err = zw.Create(fmt.Sprintf("%s/", filepath.Join("some-dir", "some-other-dir"))) Expect(err).NotTo(HaveOccurred()) - fileHeader := &zip.FileHeader{Name: filepath.Join("some-dir", "some-other-dir", "some-file")} + fileHeader = &zip.FileHeader{Name: filepath.Join("some-dir", "some-other-dir", "some-file")} fileHeader.SetMode(0644) nestedFile, err := zw.CreateHeader(fileHeader) @@ -59,15 +68,6 @@ func testVacationZip(t *testing.T, context spec.G, it spec.S) { Expect(err).NotTo(HaveOccurred()) } - fileHeader = &zip.FileHeader{Name: "symlink"} - fileHeader.SetMode(0755 | os.ModeSymlink) - - symlink, err := zw.CreateHeader(fileHeader) - Expect(err).NotTo(HaveOccurred()) - - _, err = symlink.Write([]byte(filepath.Join("some-dir", "some-other-dir", "some-file"))) - Expect(err).NotTo(HaveOccurred()) - Expect(zw.Close()).To(Succeed()) zipArchive = vacation.NewZipArchive(bytes.NewReader(buffer.Bytes())) @@ -194,7 +194,7 @@ func testVacationZip(t *testing.T, context spec.G, it spec.S) { }) }) - context("when it tries to symlink that tries to link to a file outside of the directory", func() { + context("when it tries to symlink to a file that does not exist", func() { var buffer *bytes.Buffer it.Before(func() { var err error @@ -218,14 +218,18 @@ func testVacationZip(t *testing.T, context spec.G, it spec.S) { readyArchive := vacation.NewZipArchive(buffer) err := readyArchive.Decompress(tempDir) - Expect(err).To(MatchError(ContainSubstring(fmt.Sprintf("illegal file path %q: the file path does not occur within the destination directory", filepath.Join("..", "some-file"))))) + Expect(err).To(MatchError(ContainSubstring("no such file or directory"))) }) }) - context("when it fails to unzip a symlink", func() { + context("when it tries to symlink that tries to link to a file outside of the directory", func() { var buffer *bytes.Buffer it.Before(func() { var err error + + Expect(os.MkdirAll(filepath.Join(tempDir, "sub-dir"), os.ModePerm)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(tempDir, "some-file"), nil, 0644)).To(Succeed()) + buffer = bytes.NewBuffer(nil) zw := zip.NewWriter(buffer) @@ -235,16 +239,43 @@ func testVacationZip(t *testing.T, context spec.G, it spec.S) { symlink, err := zw.CreateHeader(header) Expect(err).NotTo(HaveOccurred()) - _, err = symlink.Write([]byte(filepath.Join("some", "path", "to", "a", "target"))) + _, err = symlink.Write([]byte(filepath.Join("..", "some-file"))) Expect(err).NotTo(HaveOccurred()) Expect(zw.Close()).To(Succeed()) - Expect(os.Chmod(tempDir, 0000)).To(Succeed()) }) - it.After(func() { - Expect(os.Chmod(tempDir, os.ModePerm)).To(Succeed()) + it("returns an error", func() { + readyArchive := vacation.NewZipArchive(buffer) + + err := readyArchive.Decompress(filepath.Join(tempDir, "sub-dir")) + Expect(err).To(MatchError(ContainSubstring(fmt.Sprintf("illegal file path %q: the file path does not occur within the destination directory", filepath.Join("..", "some-file"))))) + }) + }) + + context("when the symlink creation fails", func() { + var buffer *bytes.Buffer + it.Before(func() { + var err error + buffer = bytes.NewBuffer(nil) + zw := zip.NewWriter(buffer) + + header := &zip.FileHeader{Name: "symlink"} + header.SetMode(0755 | os.ModeSymlink) + + symlink, err := zw.CreateHeader(header) + Expect(err).NotTo(HaveOccurred()) + + _, err = symlink.Write([]byte(filepath.Join("some-file"))) + Expect(err).NotTo(HaveOccurred()) + + Expect(zw.Close()).To(Succeed()) + + // Create a symlink in the target to force the new symlink create to + // fail + Expect(os.WriteFile(filepath.Join(tempDir, "some-file"), nil, 0644)).To(Succeed()) + Expect(os.Symlink("some-file", filepath.Join(tempDir, "symlink"))).To(Succeed()) }) it("returns an error", func() {