diff --git a/io/fileutils.go b/io/fileutils.go index 77cff22..091b401 100644 --- a/io/fileutils.go +++ b/io/fileutils.go @@ -240,7 +240,15 @@ func GetFileInfo(path string, followSymlink bool) (fileInfo os.FileInfo, err err // Close the reader/writer and append the error to the given error. func Close(closer io.Closer, err *error) { - if closeErr := closer.Close(); closeErr != nil { - *err = errors.Join(*err, fmt.Errorf("failed to close %T: %w", closer, closeErr)) + var closeErr error + if closeErr = closer.Close(); closeErr == nil { + return } + closeErr = fmt.Errorf("failed to close %T: %w", closer, closeErr) + if err == nil { + err = &closeErr + return + } + *err = errors.Join(*err, closeErr) + return } diff --git a/io/fileutils_test.go b/io/fileutils_test.go index f727854..9742a9c 100644 --- a/io/fileutils_test.go +++ b/io/fileutils_test.go @@ -2,17 +2,17 @@ package io import ( "errors" - "github.com/stretchr/testify/assert" "os" - "path" + "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestClose(t *testing.T) { var err error - t.TempDir() - f, err := os.Create(path.Join(t.TempDir(), "test")) + f, err := os.Create(filepath.Join(t.TempDir(), "test")) assert.NoError(t, err) Close(f, &err) @@ -26,4 +26,12 @@ func TestClose(t *testing.T) { err = errors.New("original error") Close(f, &err) assert.Len(t, strings.Split(err.Error(), "\n"), 2) + + nilErr := new(error) + Close(f, nilErr) + assert.NotNil(t, nilErr) +} + +func getNilErr() error { + return nil }