Skip to content

Commit

Permalink
Allow Stripping Components for Unarchiving (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
omerzi authored Jan 8, 2024
1 parent ca20703 commit 69aad24
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 8 deletions.
17 changes: 14 additions & 3 deletions unarchive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

type Unarchiver struct {
BypassInspection bool
StripComponents int
}

var supportedArchives = []archiver.ExtensionChecker{
Expand All @@ -39,7 +40,7 @@ func (u *Unarchiver) IsSupportedArchive(filePath string) bool {
// archiveName - The archive file name
// destinationPath - The extraction destination directory
func (u *Unarchiver) Unarchive(archivePath, archiveName, destinationPath string) error {
archive, err := byExtension(archiveName)
archive, err := u.byExtension(archiveName)
if err != nil {
return err
}
Expand All @@ -52,11 +53,12 @@ func (u *Unarchiver) Unarchive(archivePath, archiveName, destinationPath string)
return err
}
}

return unarchiver.Unarchive(archivePath, destinationPath)
}

// Instead of using 'archiver.byExtension' that by default sets OverwriteExisting to false, we implement our own.
func byExtension(filename string) (interface{}, error) {
func (u *Unarchiver) byExtension(filename string) (interface{}, error) {
var ec interface{}
for _, c := range supportedArchives {
if err := c.CheckExt(filename); err == nil {
Expand All @@ -68,45 +70,54 @@ func byExtension(filename string) (interface{}, error) {
case *archiver.Rar:
archiveInstance := archiver.NewRar()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.Tar:
archiveInstance := archiver.NewTar()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarBrotli:
archiveInstance := archiver.NewTarBrotli()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarBz2:
archiveInstance := archiver.NewTarBz2()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarGz:
archiveInstance := archiver.NewTarGz()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarLz4:
archiveInstance := archiver.NewTarLz4()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarSz:
archiveInstance := archiver.NewTarSz()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarXz:
archiveInstance := archiver.NewTarXz()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.TarZstd:
archiveInstance := archiver.NewTarZstd()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.Zip:
archiveInstance := archiver.NewZip()
archiveInstance.OverwriteExisting = true
archiveInstance.StripComponents = u.StripComponents
return archiveInstance, nil
case *archiver.Gz:
archiver.NewGz()
return archiver.NewGz(), nil
case *archiver.Bz2:
return archiver.NewBz2(), nil
Expand Down
37 changes: 32 additions & 5 deletions unarchive/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@ import (

func TestUnarchive(t *testing.T) {
tests := []string{"zip", "tar", "tar.gz"}
uarchiver := Unarchiver{}
for _, extension := range tests {
t.Run(extension, func(t *testing.T) {
// Create temp directory
tmpDir, createTempDirCallback := createTempDirWithCallbackAndAssert(t)
defer createTempDirCallback()
// Run unarchive on archive created on Unix
err := runUnarchive(t, "unix."+extension, "archives", filepath.Join(tmpDir, "unix"))
err := runUnarchive(t, uarchiver, "unix."+extension, "archives", filepath.Join(tmpDir, "unix"))
assert.NoError(t, err)
assert.FileExists(t, filepath.Join(tmpDir, "unix", "link"))
assert.FileExists(t, filepath.Join(tmpDir, "unix", "dir", "file"))

// Run unarchive on archive created on Windows
err = runUnarchive(t, "win."+extension, "archives", filepath.Join(tmpDir, "win"))
err = runUnarchive(t, uarchiver, "win."+extension, "archives", filepath.Join(tmpDir, "win"))
assert.NoError(t, err)
assert.FileExists(t, filepath.Join(tmpDir, "win", "link.lnk"))
assert.FileExists(t, filepath.Join(tmpDir, "win", "dir", "file.txt"))
Expand All @@ -41,6 +42,7 @@ var unarchiveSymlinksCases = []struct {

func TestUnarchiveSymlink(t *testing.T) {
testExtensions := []string{"zip", "tar", "tar.gz"}
uarchiver := Unarchiver{}
for _, extension := range testExtensions {
t.Run(extension, func(t *testing.T) {
for _, testCase := range unarchiveSymlinksCases {
Expand All @@ -50,7 +52,7 @@ func TestUnarchiveSymlink(t *testing.T) {
defer createTempDirCallback()

// Run unarchive
err := runUnarchive(t, testCase.prefix+"."+extension, "archives", tmpDir)
err := runUnarchive(t, uarchiver, testCase.prefix+"."+extension, "archives", tmpDir)
assert.NoError(t, err)

// Assert the all expected files were extracted
Expand All @@ -77,23 +79,48 @@ func TestUnarchiveZipSlip(t *testing.T) {
{"softlink-uncle", []string{"zip", "tar", "tar.gz"}, "a link can't lead to an ancestor directory"},
{"hardlink-tilde", []string{"tar", "tar.gz"}, "walking hardlink: illegal link path in archive: '~/../../../../../../../../../Users/Shared/sharedFile.txt'"},
}

uarchiver := Unarchiver{}
for _, test := range tests {
t.Run(test.testType, func(t *testing.T) {
// Create temp directory
tmpDir, createTempDirCallback := createTempDirWithCallbackAndAssert(t)
defer createTempDirCallback()
for _, archive := range test.archives {
// Unarchive and make sure an error returns
err := runUnarchive(t, test.testType+"."+archive, "zipslip", tmpDir)
err := runUnarchive(t, uarchiver, test.testType+"."+archive, "zipslip", tmpDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), test.errorSuffix)
}
})
}
}

func runUnarchive(t *testing.T, archiveFileName, sourceDir, targetDir string) error {
func TestUnarchiveWithStripComponents(t *testing.T) {
tests := []string{"zip", "tar", "tar.gz"}
uarchiver := Unarchiver{}
uarchiver.StripComponents = 1
for _, extension := range tests {
t.Run(extension, func(t *testing.T) {
// Create temp directory
tmpDir, createTempDirCallback := createTempDirWithCallbackAndAssert(t)
defer createTempDirCallback()
// Run unarchive on archive created on Unix
err := runUnarchive(t, uarchiver, "strip-components."+extension, "archives", filepath.Join(tmpDir, "unix"))
assert.NoError(t, err)
assert.DirExists(t, filepath.Join(tmpDir, "unix", "nested_folder_1"))
assert.DirExists(t, filepath.Join(tmpDir, "unix", "nested_folder_2"))

// Run unarchive on archive created on Windows
err = runUnarchive(t, uarchiver, "strip-components."+extension, "archives", filepath.Join(tmpDir, "win"))
assert.NoError(t, err)
assert.DirExists(t, filepath.Join(tmpDir, "win", "nested_folder_1"))
assert.DirExists(t, filepath.Join(tmpDir, "win", "nested_folder_2"))
})
}
}

func runUnarchive(t *testing.T, uarchiver Unarchiver, archiveFileName, sourceDir, targetDir string) error {
archivePath := filepath.Join("testdata", sourceDir, archiveFileName)
assert.True(t, uarchiver.IsSupportedArchive(archivePath))
return uarchiver.Unarchive(filepath.Join("testdata", sourceDir, archiveFileName), archiveFileName, targetDir)
Expand Down
Binary file added unarchive/testdata/archives/strip-components.tar
Binary file not shown.
Binary file added unarchive/testdata/archives/strip-components.tar.gz
Binary file not shown.
Binary file added unarchive/testdata/archives/strip-components.zip
Binary file not shown.

0 comments on commit 69aad24

Please sign in to comment.