diff --git a/unarchive/archive.go b/unarchive/archive.go index 1e15495..c2522e9 100644 --- a/unarchive/archive.go +++ b/unarchive/archive.go @@ -14,6 +14,7 @@ import ( type Unarchiver struct { BypassInspection bool + StripComponents int } var supportedArchives = []archiver.ExtensionChecker{ @@ -39,7 +40,7 @@ func (u *Unarchiver) IsSupportedArchive(filePath string) bool { // archiveName - The archive file name // destinationPath - The extraction destination directory func (u *Unarchiver) Unarchive(archivePath, archiveName, destinationPath string) error { - archive, err := byExtension(archiveName) + archive, err := u.byExtension(archiveName) if err != nil { return err } @@ -52,11 +53,12 @@ func (u *Unarchiver) Unarchive(archivePath, archiveName, destinationPath string) return err } } + return unarchiver.Unarchive(archivePath, destinationPath) } // Instead of using 'archiver.byExtension' that by default sets OverwriteExisting to false, we implement our own. -func byExtension(filename string) (interface{}, error) { +func (u *Unarchiver) byExtension(filename string) (interface{}, error) { var ec interface{} for _, c := range supportedArchives { if err := c.CheckExt(filename); err == nil { @@ -68,45 +70,54 @@ func byExtension(filename string) (interface{}, error) { case *archiver.Rar: archiveInstance := archiver.NewRar() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.Tar: archiveInstance := archiver.NewTar() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarBrotli: archiveInstance := archiver.NewTarBrotli() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarBz2: archiveInstance := archiver.NewTarBz2() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarGz: archiveInstance := archiver.NewTarGz() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarLz4: archiveInstance := archiver.NewTarLz4() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarSz: archiveInstance := archiver.NewTarSz() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarXz: archiveInstance := archiver.NewTarXz() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.TarZstd: archiveInstance := archiver.NewTarZstd() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.Zip: archiveInstance := archiver.NewZip() archiveInstance.OverwriteExisting = true + archiveInstance.StripComponents = u.StripComponents return archiveInstance, nil case *archiver.Gz: - archiver.NewGz() return archiver.NewGz(), nil case *archiver.Bz2: return archiver.NewBz2(), nil diff --git a/unarchive/archive_test.go b/unarchive/archive_test.go index fdbc489..bc932c0 100644 --- a/unarchive/archive_test.go +++ b/unarchive/archive_test.go @@ -10,19 +10,20 @@ import ( func TestUnarchive(t *testing.T) { tests := []string{"zip", "tar", "tar.gz"} + uarchiver := Unarchiver{} for _, extension := range tests { t.Run(extension, func(t *testing.T) { // Create temp directory tmpDir, createTempDirCallback := createTempDirWithCallbackAndAssert(t) defer createTempDirCallback() // Run unarchive on archive created on Unix - err := runUnarchive(t, "unix."+extension, "archives", filepath.Join(tmpDir, "unix")) + err := runUnarchive(t, uarchiver, "unix."+extension, "archives", filepath.Join(tmpDir, "unix")) assert.NoError(t, err) assert.FileExists(t, filepath.Join(tmpDir, "unix", "link")) assert.FileExists(t, filepath.Join(tmpDir, "unix", "dir", "file")) // Run unarchive on archive created on Windows - err = runUnarchive(t, "win."+extension, "archives", filepath.Join(tmpDir, "win")) + err = runUnarchive(t, uarchiver, "win."+extension, "archives", filepath.Join(tmpDir, "win")) assert.NoError(t, err) assert.FileExists(t, filepath.Join(tmpDir, "win", "link.lnk")) assert.FileExists(t, filepath.Join(tmpDir, "win", "dir", "file.txt")) @@ -41,6 +42,7 @@ var unarchiveSymlinksCases = []struct { func TestUnarchiveSymlink(t *testing.T) { testExtensions := []string{"zip", "tar", "tar.gz"} + uarchiver := Unarchiver{} for _, extension := range testExtensions { t.Run(extension, func(t *testing.T) { for _, testCase := range unarchiveSymlinksCases { @@ -50,7 +52,7 @@ func TestUnarchiveSymlink(t *testing.T) { defer createTempDirCallback() // Run unarchive - err := runUnarchive(t, testCase.prefix+"."+extension, "archives", tmpDir) + err := runUnarchive(t, uarchiver, testCase.prefix+"."+extension, "archives", tmpDir) assert.NoError(t, err) // Assert the all expected files were extracted @@ -77,6 +79,8 @@ func TestUnarchiveZipSlip(t *testing.T) { {"softlink-uncle", []string{"zip", "tar", "tar.gz"}, "a link can't lead to an ancestor directory"}, {"hardlink-tilde", []string{"tar", "tar.gz"}, "walking hardlink: illegal link path in archive: '~/../../../../../../../../../Users/Shared/sharedFile.txt'"}, } + + uarchiver := Unarchiver{} for _, test := range tests { t.Run(test.testType, func(t *testing.T) { // Create temp directory @@ -84,7 +88,7 @@ func TestUnarchiveZipSlip(t *testing.T) { defer createTempDirCallback() for _, archive := range test.archives { // Unarchive and make sure an error returns - err := runUnarchive(t, test.testType+"."+archive, "zipslip", tmpDir) + err := runUnarchive(t, uarchiver, test.testType+"."+archive, "zipslip", tmpDir) assert.Error(t, err) assert.Contains(t, err.Error(), test.errorSuffix) } @@ -92,8 +96,31 @@ func TestUnarchiveZipSlip(t *testing.T) { } } -func runUnarchive(t *testing.T, archiveFileName, sourceDir, targetDir string) error { +func TestUnarchiveWithStripComponents(t *testing.T) { + tests := []string{"zip", "tar", "tar.gz"} uarchiver := Unarchiver{} + uarchiver.StripComponents = 1 + for _, extension := range tests { + t.Run(extension, func(t *testing.T) { + // Create temp directory + tmpDir, createTempDirCallback := createTempDirWithCallbackAndAssert(t) + defer createTempDirCallback() + // Run unarchive on archive created on Unix + err := runUnarchive(t, uarchiver, "strip-components."+extension, "archives", filepath.Join(tmpDir, "unix")) + assert.NoError(t, err) + assert.DirExists(t, filepath.Join(tmpDir, "unix", "nested_folder_1")) + assert.DirExists(t, filepath.Join(tmpDir, "unix", "nested_folder_2")) + + // Run unarchive on archive created on Windows + err = runUnarchive(t, uarchiver, "strip-components."+extension, "archives", filepath.Join(tmpDir, "win")) + assert.NoError(t, err) + assert.DirExists(t, filepath.Join(tmpDir, "win", "nested_folder_1")) + assert.DirExists(t, filepath.Join(tmpDir, "win", "nested_folder_2")) + }) + } +} + +func runUnarchive(t *testing.T, uarchiver Unarchiver, archiveFileName, sourceDir, targetDir string) error { archivePath := filepath.Join("testdata", sourceDir, archiveFileName) assert.True(t, uarchiver.IsSupportedArchive(archivePath)) return uarchiver.Unarchive(filepath.Join("testdata", sourceDir, archiveFileName), archiveFileName, targetDir) diff --git a/unarchive/testdata/archives/strip-components.tar b/unarchive/testdata/archives/strip-components.tar new file mode 100644 index 0000000..ebf29fd Binary files /dev/null and b/unarchive/testdata/archives/strip-components.tar differ diff --git a/unarchive/testdata/archives/strip-components.tar.gz b/unarchive/testdata/archives/strip-components.tar.gz new file mode 100644 index 0000000..604792a Binary files /dev/null and b/unarchive/testdata/archives/strip-components.tar.gz differ diff --git a/unarchive/testdata/archives/strip-components.zip b/unarchive/testdata/archives/strip-components.zip new file mode 100644 index 0000000..c9c9f89 Binary files /dev/null and b/unarchive/testdata/archives/strip-components.zip differ