diff --git a/cmd/artifact/install/install.go b/cmd/artifact/install/install.go index b8f49a26..60a90669 100644 --- a/cmd/artifact/install/install.go +++ b/cmd/artifact/install/install.go @@ -337,7 +337,7 @@ func (o *artifactInstallOptions) RunArtifactInstall(ctx context.Context, args [] return err } // Extract artifact and move it to its destination directory - _, err = utils.ExtractTarGz(f, destDir, 0) + _, err = utils.ExtractTarGz(ctx, f, destDir, 0) if err != nil { return fmt.Errorf("cannot extract %q to %q: %w", result.Filename, destDir, err) } diff --git a/internal/follower/follower.go b/internal/follower/follower.go index 398c6fcf..36060d38 100644 --- a/internal/follower/follower.go +++ b/internal/follower/follower.go @@ -291,7 +291,7 @@ func (f *Follower) pull(ctx context.Context) (filePaths []string, res *oci.Regis } // Extract artifact and move it to its destination directory - filePaths, err = utils.ExtractTarGz(file, f.tmpDir, 0) + filePaths, err = utils.ExtractTarGz(ctx, file, f.tmpDir, 0) if err != nil { return filePaths, res, fmt.Errorf("unable to extract %q to %q: %w", res.Filename, f.tmpDir, err) } diff --git a/internal/utils/extract.go b/internal/utils/extract.go index 27cd7120..72ec6571 100644 --- a/internal/utils/extract.go +++ b/internal/utils/extract.go @@ -24,12 +24,30 @@ import ( "os" "path/filepath" "strings" + + "golang.org/x/net/context" ) +type link struct { + Name string + Path string +} + // ExtractTarGz extracts a *.tar.gz compressed archive and moves its content to destDir. // Returns a slice containing the full path of the extracted files. -func ExtractTarGz(gzipStream io.Reader, destDir string, stripPathComponents int) ([]string, error) { - var files []string +func ExtractTarGz(ctx context.Context, gzipStream io.Reader, destDir string, stripPathComponents int) ([]string, error) { + var ( + files []string + links []link + symlinks []link + err error + ) + + // We need an absolute path + destDir, err = filepath.Abs(destDir) + if err != nil { + return nil, err + } uncompressedStream, err := gzip.NewReader(gzipStream) if err != nil { @@ -37,34 +55,46 @@ func ExtractTarGz(gzipStream io.Reader, destDir string, stripPathComponents int) } tarReader := tar.NewReader(uncompressedStream) - for { - header, err := tarReader.Next() + select { + case <-ctx.Done(): + return nil, errors.New("interrupted") + default: + } + header, err := tarReader.Next() if errors.Is(err, io.EOF) { break } - if err != nil { return nil, err } - if strings.Contains(header.Name, "..") { return nil, fmt.Errorf("not allowed relative path in tar archive") } - strippedName := stripComponents(header.Name, stripPathComponents) + path := header.Name + if stripPathComponents > 0 { + path = stripComponents(path, stripPathComponents) + } + if path == "" { + continue + } + + if path, err = safeConcat(destDir, filepath.Clean(path)); err != nil { + // Skip paths that would escape destDir + continue + } + info := header.FileInfo() + files = append(files, path) switch header.Typeflag { case tar.TypeDir: - d := filepath.Join(destDir, strippedName) - if err = os.MkdirAll(filepath.Clean(d), 0o750); err != nil { + if err = os.MkdirAll(path, info.Mode()); err != nil { return nil, err } - files = append(files, d) case tar.TypeReg: - f := filepath.Join(destDir, strippedName) - outFile, err := os.Create(filepath.Clean(f)) + outFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, info.Mode()) if err != nil { return nil, err } @@ -76,27 +106,46 @@ func ExtractTarGz(gzipStream io.Reader, destDir string, stripPathComponents int) if err = outFile.Close(); err != nil { return nil, err } - if err = os.Chmod(filepath.Clean(f), header.FileInfo().Mode()); err != nil { - return nil, err - } - files = append(files, f) - case tar.TypeLink, tar.TypeSymlink: - strippedSrcName := stripComponents(header.Linkname, stripPathComponents) - fDst := filepath.Join(destDir, strippedName) - if header.Typeflag == tar.TypeSymlink { - err = os.Symlink(filepath.Clean(strippedSrcName), filepath.Clean(fDst)) - } else { - err = os.Link(filepath.Clean(strippedSrcName), filepath.Clean(fDst)) + case tar.TypeLink: + name := header.Linkname + if stripPathComponents > 0 { + name = stripComponents(name, stripPathComponents) } - if err != nil { - return nil, err + if name == "" { + continue } - files = append(files, fDst) + + name = filepath.Join(destDir, filepath.Clean(name)) + links = append(links, link{Path: path, Name: name}) + case tar.TypeSymlink: + symlinks = append(symlinks, link{Path: path, Name: header.Linkname}) default: return nil, fmt.Errorf("extractTarGz: uknown type: %b in %s", header.Typeflag, header.Name) } } + // Now we make another pass creating the links + for i := range links { + select { + case <-ctx.Done(): + return nil, errors.New("interrupted") + default: + } + if err = os.Link(links[i].Name, links[i].Path); err != nil { + return nil, err + } + } + + for i := range symlinks { + select { + case <-ctx.Done(): + return nil, errors.New("interrupted") + default: + } + if err = os.Symlink(symlinks[i].Name, symlinks[i].Path); err != nil { + return nil, err + } + } return files, nil } @@ -104,11 +153,22 @@ func stripComponents(headerName string, stripComponents int) string { if stripComponents == 0 { return headerName } - names := strings.FieldsFunc(headerName, func(r rune) bool { - return r == os.PathSeparator - }) + names := strings.Split(headerName, string(filepath.Separator)) if len(names) < stripComponents { return headerName } - return filepath.Clean(strings.Join(names[stripComponents:], string(os.PathSeparator))) + return filepath.Clean(strings.Join(names[stripComponents:], string(filepath.Separator))) +} + +// safeConcat concatenates destDir and name +// but returns an error if the resulting path points outside 'destDir'. +func safeConcat(destDir, name string) (string, error) { + res := filepath.Join(destDir, name) + if !strings.HasSuffix(destDir, string(os.PathSeparator)) { + destDir += string(os.PathSeparator) + } + if !strings.HasPrefix(res, destDir) { + return res, fmt.Errorf("unsafe path concatenation: '%s' with '%s'", destDir, name) + } + return res, nil } diff --git a/internal/utils/extract_test.go b/internal/utils/extract_test.go index 411df7ba..c70916c1 100644 --- a/internal/utils/extract_test.go +++ b/internal/utils/extract_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "golang.org/x/net/context" ) const ( @@ -112,7 +113,7 @@ func TestExtractTarGz(t *testing.T) { }) // Create dest folder - destDir := "./test/" + destDir := "./test" err = os.MkdirAll(destDir, 0o750) assert.NoError(t, err) t.Cleanup(func() { @@ -126,10 +127,11 @@ func TestExtractTarGz(t *testing.T) { f.Close() }) - list, err := ExtractTarGz(f, destDir, 0) + list, err := ExtractTarGz(context.TODO(), f, destDir, 0) assert.NoError(t, err) // Final checks + assert.NotEmpty(t, list) // All extracted files are ok for _, f := range list { @@ -138,8 +140,10 @@ func TestExtractTarGz(t *testing.T) { } // Extracted folder contains all source files (plus folders) + absDestDir, err := filepath.Abs(destDir) + assert.NoError(t, err) for _, f := range files { - path := filepath.Join(destDir, f) + path := filepath.Join(absDestDir, f) assert.Contains(t, list, path) } } @@ -168,7 +172,7 @@ func TestExtractTarGzStripComponents(t *testing.T) { }) // Create dest folder - destdirStrip := "./test_strip/" + destdirStrip := "./test_strip" err = os.MkdirAll(destdirStrip, 0o750) assert.NoError(t, err) t.Cleanup(func() { @@ -182,10 +186,11 @@ func TestExtractTarGzStripComponents(t *testing.T) { f.Close() }) // NOTE that here we strip first component - list, err := ExtractTarGz(f, destdirStrip, 1) + list, err := ExtractTarGz(context.TODO(), f, destdirStrip, 1) assert.NoError(t, err) // Final checks + assert.NotEmpty(t, list) // All extracted files are ok for _, f := range list { @@ -194,10 +199,12 @@ func TestExtractTarGzStripComponents(t *testing.T) { } // Extracted folder contains all source files (plus folders) + absDestDirStrip, err := filepath.Abs(destdirStrip) + assert.NoError(t, err) for _, f := range files { // We stripped first component (ie: srcDir) ff := strings.TrimPrefix(f, srcDir) - path := filepath.Join(destdirStrip, ff) + path := filepath.Join(absDestDirStrip, ff) assert.Contains(t, list, path) } } diff --git a/pkg/driver/distro/distro.go b/pkg/driver/distro/distro.go index d11f7203..2fb348c0 100644 --- a/pkg/driver/distro/distro.go +++ b/pkg/driver/distro/distro.go @@ -319,7 +319,7 @@ func downloadKernelSrc(ctx context.Context, return env, err } - _, err = utils.ExtractTarGz(resp.Body, fullKernelDir, stripComponents) + _, err = utils.ExtractTarGz(ctx, resp.Body, fullKernelDir, stripComponents) if err != nil { return env, err }