diff --git a/CHANGELOG.md b/CHANGELOG.md index 7436896..12b6f70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,22 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] ## +### Changed ### +- Passing the `S_ISUID` or `S_ISGID` modes to `MkdirAllInRoot` will now return + an explicit error saying that those bits are ignored by `mkdirat(2)`. In the + past a different error was returned, but since the silent ignoring behaviour + is codified in the man pages a more explicit error seems apt. While silently + ignoring these bits would be the most compatible option, it could lead to + users thinking their code sets these bits when it doesn't. Programs that need + to deal with compatibility can mask the bits themselves. (#23, #25) + +## Fixes ## +- If a directory has `S_ISGID` set, then all child directories will have + `S_ISGID` set when created and a different gid will be used for any inode + created under the directory. Previously, the "expected owner and mode" + validation in `securejoin.MkdirAll` did not correctly handle this. We now + correctly handle this case. (#24, #25) + ## [0.3.1] - 2024-07-23 ## ### Changed ### diff --git a/mkdir_linux.go b/mkdir_linux.go index ad2bd79..49ffdbe 100644 --- a/mkdir_linux.go +++ b/mkdir_linux.go @@ -46,6 +46,13 @@ func MkdirAllHandle(root *os.File, unsafePath string, mode int) (_ *os.File, Err if mode&^0o7777 != 0 { return nil, fmt.Errorf("%w for mkdir 0o%.3o", errInvalidMode, mode) } + // On Linux, mkdirat(2) (and os.Mkdir) silently ignore the suid and sgid + // bits. We could also silently ignore them but since we have very few + // users it seems more prudent to return an error so users notice that + // these bits will not be set. + if mode&^0o1777 != 0 { + return nil, fmt.Errorf("%w for mkdir 0o%.3o: suid and sgid are ignored by mkdir", errInvalidMode, mode) + } // Try to open as much of the path as possible. currentDir, remainingPath, err := partialLookupInRoot(root, unsafePath) @@ -120,6 +127,17 @@ func MkdirAllHandle(root *os.File, unsafePath string, mode int) (_ *os.File, Err expectedGid = uint32(unix.Getegid()) ) + // The setgid bit (S_ISGID = 0o2000) is inherited to child directories and + // affects the group of any inodes created in said directory, so if the + // starting directory has it set we need to adjust our expected mode and + // owner to match. + if st, err := fstatFile(currentDir); err != nil { + return nil, fmt.Errorf("failed to stat starting path for mkdir %q: %w", currentDir.Name(), err) + } else if st.Mode&unix.S_ISGID == unix.S_ISGID { + expectedMode |= unix.S_ISGID + expectedGid = st.Gid + } + // Create the remaining components. for _, part := range remainingParts { switch part { diff --git a/mkdir_linux_test.go b/mkdir_linux_test.go index bdad1c9..b32df3f 100644 --- a/mkdir_linux_test.go +++ b/mkdir_linux_test.go @@ -19,7 +19,85 @@ import ( "golang.org/x/sys/unix" ) -func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePath string, mode int) error) { +type mkdirAllFunc func(t *testing.T, root, unsafePath string, mode int) error + +var mkdirAll_MkdirAll mkdirAllFunc = func(t *testing.T, root, unsafePath string, mode int) error { + // We can't check expectedPath here. + return MkdirAll(root, unsafePath, mode) +} + +var mkdirAll_MkdirAllHandle mkdirAllFunc = func(t *testing.T, root, unsafePath string, mode int) error { + // Same logic as MkdirAll. + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + if err != nil { + return err + } + defer rootDir.Close() + handle, err := MkdirAllHandle(rootDir, unsafePath, mode) + if err != nil { + return err + } + defer handle.Close() + + // We can use SecureJoin here becuase we aren't being attacked in this + // particular test. Obviously this check is bogus for actual programs. + expectedPath, err := SecureJoin(root, unsafePath) + require.NoError(t, err) + + // Now double-check that the handle is correct. + gotPath, err := procSelfFdReadlink(handle) + require.NoError(t, err, "get real path of returned handle") + assert.Equal(t, expectedPath, gotPath, "wrong final path from MkdirAllHandle") + // Also check that the f.Name() is correct while we're at it (this is + // not always guaranteed but it's better to try at least). + assert.Equal(t, expectedPath, handle.Name(), "handle from MkdirAllHandle has the wrong .Name()") + return nil +} + +func checkMkdirAll(t *testing.T, mkdirAll mkdirAllFunc, root, unsafePath string, mode, expectedMode int, expectedErr error) { + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + require.NoError(t, err) + defer rootDir.Close() + + // Before trying to make the tree, figure out what components don't exist + // yet so we can check them later. + handle, remainingPath, err := partialLookupInRoot(rootDir, unsafePath) + handleName := "" + if handle != nil { + handleName = handle.Name() + defer handle.Close() + } + defer func() { + if t.Failed() { + t.Logf("partialLookupInRoot(%s, %s) -> (<%s>, %s, %v)", root, unsafePath, handleName, remainingPath, err) + } + }() + + // Actually make the tree. + err = mkdirAll(t, root, unsafePath, mode) + assert.ErrorIsf(t, err, expectedErr, "MkdirAll(%q, %q)", root, unsafePath) + + remainingPath = filepath.Join("/", remainingPath) + for remainingPath != filepath.Dir(remainingPath) { + stat, err := fstatatFile(handle, "./"+remainingPath, unix.AT_SYMLINK_NOFOLLOW) + if expectedErr == nil { + // Check that the new components have the right mode. + if assert.NoErrorf(t, err, "unexpected error when checking new directory %q", remainingPath) { + assert.Equalf(t, uint32(unix.S_IFDIR|expectedMode), stat.Mode, "new directory %q has the wrong mode", remainingPath) + } + } else { + // Check that none of the components are directories (i.e. make + // sure that the MkdirAll was a no-op). + if err == nil { + assert.NotEqualf(t, uint32(unix.S_IFDIR), stat.Mode&unix.S_IFMT, "failed MkdirAll created a new directory at %q", remainingPath) + } + } + // Jump up a level. + remainingPath = filepath.Dir(remainingPath) + } +} + +func testMkdirAll_Basic(t *testing.T, mkdirAll mkdirAllFunc) { // We create a new tree for each test, but the template is the same. tree := []string{ "dir a", @@ -47,12 +125,16 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa // Symlink loop. "dir loop", "symlink loop/link ../loop/link", + // S_ISGID directory. + "dir sgid-self ::2755", + "dir sgid-sticky-self ::3755", } withWithoutOpenat2(t, true, func(t *testing.T) { for name, test := range map[string]struct { - unsafePath string - expectedErr error + unsafePath string + expectedErr error + expectedModeBits int }{ "existing": {unsafePath: "a"}, "basic": {unsafePath: "a/b/c/d/e/f/g/h/i/j"}, @@ -99,99 +181,71 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa "loop-trailing": {unsafePath: "loop/link", expectedErr: unix.ELOOP}, "loop-basic": {unsafePath: "loop/link/foo", expectedErr: unix.ELOOP}, "loop-dotdot": {unsafePath: "loop/link/../foo", expectedErr: unix.ELOOP}, + // Make sure the S_ISGID handling is correct. + "sgid-dir-ownedbyus": {unsafePath: "sgid-self/foo/bar/baz", expectedModeBits: unix.S_ISGID}, + "sgid-sticky-dir-ownedbyus": {unsafePath: "sgid-sticky-self/foo/bar/baz", expectedModeBits: unix.S_ISGID}, } { test := test // copy iterator t.Run(name, func(t *testing.T) { root := createTree(t, tree...) - - rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) - require.NoError(t, err) - defer rootDir.Close() - - // Before trying to make the tree, figure out what - // components don't exist yet so we can check them later. - handle, remainingPath, err := partialLookupInRoot(rootDir, test.unsafePath) - handleName := "" - if handle != nil { - handleName = handle.Name() - defer handle.Close() - } - defer func() { - if t.Failed() { - t.Logf("partialLookupInRoot(%s, %s) -> (<%s>, %s, %v)", root, test.unsafePath, handleName, remainingPath, err) - } - }() - - // This mode is different to the one set up by createTree. - const expectedMode = 0o711 - - // Actually make the tree. - err = mkdirAll(t, root, test.unsafePath, 0o711) - assert.ErrorIsf(t, err, test.expectedErr, "MkdirAll(%q, %q)", root, test.unsafePath) - - remainingPath = filepath.Join("/", remainingPath) - for remainingPath != filepath.Dir(remainingPath) { - stat, err := fstatatFile(handle, "./"+remainingPath, unix.AT_SYMLINK_NOFOLLOW) - if test.expectedErr == nil { - // Check that the new components have the right - // mode. - if assert.NoErrorf(t, err, "unexpected error when checking new directory %q", remainingPath) { - assert.Equalf(t, uint32(unix.S_IFDIR|expectedMode), stat.Mode, "new directory %q has the wrong mode", remainingPath) - } - } else { - // Check that none of the components are - // directories (i.e. make sure that the MkdirAll - // was a no-op). - if err == nil { - assert.NotEqualf(t, uint32(unix.S_IFDIR), stat.Mode&unix.S_IFMT, "failed MkdirAll created a new directory at %q", remainingPath) - } - } - // Jump up a level. - remainingPath = filepath.Dir(remainingPath) - } + const mode = 0o711 + checkMkdirAll(t, mkdirAll, root, test.unsafePath, mode, test.expectedModeBits|mode, test.expectedErr) }) } }) } func TestMkdirAll_Basic(t *testing.T) { - testMkdirAll_Basic(t, func(t *testing.T, root, unsafePath string, mode int) error { - // We can't check expectedPath here. - return MkdirAll(root, unsafePath, mode) - }) + testMkdirAll_Basic(t, mkdirAll_MkdirAll) } func TestMkdirAllHandle_Basic(t *testing.T) { - testMkdirAll_Basic(t, func(t *testing.T, root, unsafePath string, mode int) error { - // Same logic as MkdirAll. - rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) - if err != nil { - return err - } - defer rootDir.Close() - handle, err := MkdirAllHandle(rootDir, unsafePath, mode) - if err != nil { - return err - } - defer handle.Close() + testMkdirAll_Basic(t, mkdirAll_MkdirAllHandle) +} - // We can use SecureJoin here becuase we aren't being attacked in this - // particular test. Obviously this check is bogus for actual programs. - expectedPath, err := SecureJoin(root, unsafePath) - require.NoError(t, err) - - // Now double-check that the handle is correct. - gotPath, err := procSelfFdReadlink(handle) - require.NoError(t, err, "get real path of returned handle") - assert.Equal(t, expectedPath, gotPath, "wrong final path from MkdirAllHandle") - // Also check that the f.Name() is correct while we're at it (this is - // not always guaranteed but it's better to try at least). - assert.Equal(t, expectedPath, handle.Name(), "handle from MkdirAllHandle has the wrong .Name()") - return nil +func testMkdirAll_AsRoot(t *testing.T, mkdirAll mkdirAllFunc) { + requireRoot(t) // chown + + // We create a new tree for each test, but the template is the same. + tree := []string{ + // S_ISGID directories. + "dir sgid-self ::2755", + "dir sgid-other 1000:1000:2755", + "dir sgid-sticky-self ::3755", + "dir sgid-sticky-other 1000:1000:3755", + } + + withWithoutOpenat2(t, true, func(t *testing.T) { + for name, test := range map[string]struct { + unsafePath string + expectedErr error + expectedModeBits int + }{ + // Make sure the S_ISGID handling is correct. + "sgid-dir-ownedbyus": {unsafePath: "sgid-self/foo/bar/baz", expectedModeBits: unix.S_ISGID}, + "sgid-dir-ownedbyother": {unsafePath: "sgid-other/foo/bar/baz", expectedModeBits: unix.S_ISGID}, + "sgid-sticky-dir-ownedbyus": {unsafePath: "sgid-sticky-self/foo/bar/baz", expectedModeBits: unix.S_ISGID}, + "sgid-sticky-dir-ownedbyother": {unsafePath: "sgid-sticky-other/foo/bar/baz", expectedModeBits: unix.S_ISGID}, + } { + test := test // copy iterator + t.Run(name, func(t *testing.T) { + root := createTree(t, tree...) + const mode = 0o711 + checkMkdirAll(t, mkdirAll, root, test.unsafePath, mode, test.expectedModeBits|mode, test.expectedErr) + }) + } }) } -func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, unsafePath string, mode int) error) { +func TestMkdirAll_AsRoot(t *testing.T) { + testMkdirAll_AsRoot(t, mkdirAll_MkdirAll) +} + +func TestMkdirAllHandle_AsRoot(t *testing.T) { + testMkdirAll_AsRoot(t, mkdirAll_MkdirAllHandle) +} + +func testMkdirAll_InvalidMode(t *testing.T, mkdirAll mkdirAllFunc) { for _, test := range []struct { mode int expectedErr error @@ -204,12 +258,11 @@ func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, un {unix.S_IFDIR | 0o777, errInvalidMode}, {unix.S_IFREG | 0o777, errInvalidMode}, {unix.S_IFIFO | 0o777, errInvalidMode}, - // suid/sgid bits are valid but you get an error because they don't get - // applied by mkdirat. - // TODO: Figure out if we want to allow this. - {unix.S_ISUID | 0o777, errPossibleAttack}, - {unix.S_ISGID | 0o777, errPossibleAttack}, - {unix.S_ISUID | unix.S_ISGID | unix.S_ISVTX | 0o777, errPossibleAttack}, + // suid/sgid bits are silently ignored by mkdirat and so we return an + // error explicitly. + {unix.S_ISUID | 0o777, errInvalidMode}, + {unix.S_ISGID | 0o777, errInvalidMode}, + {unix.S_ISUID | unix.S_ISGID | unix.S_ISVTX | 0o777, errInvalidMode}, // Proper sticky bit should work. {unix.S_ISVTX | 0o777, nil}, // Regular mode bits. @@ -223,25 +276,11 @@ func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, un } func TestMkdirAll_InvalidMode(t *testing.T) { - testMkdirAll_InvalidMode(t, func(t *testing.T, root, unsafePath string, mode int) error { - return MkdirAll(root, unsafePath, mode) - }) + testMkdirAll_InvalidMode(t, mkdirAll_MkdirAll) } func TestMkdirAllHandle_InvalidMode(t *testing.T) { - testMkdirAll_InvalidMode(t, func(t *testing.T, root, unsafePath string, mode int) error { - rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) - if err != nil { - return err - } - defer rootDir.Close() - handle, err := MkdirAllHandle(rootDir, unsafePath, mode) - if err != nil { - return err - } - _ = handle.Close() - return nil - }) + testMkdirAll_InvalidMode(t, mkdirAll_MkdirAllHandle) } type racingMkdirMeta struct { diff --git a/openat_linux.go b/openat_linux.go index 949fb5f..ac083f2 100644 --- a/openat_linux.go +++ b/openat_linux.go @@ -42,6 +42,10 @@ func fstatatFile(dir *os.File, path string, flags int) (unix.Stat_t, error) { return stat, nil } +func fstatFile(fd *os.File) (unix.Stat_t, error) { + return fstatatFile(fd, "", unix.AT_EMPTY_PATH) +} + func readlinkatFile(dir *os.File, path string) (string, error) { size := 4096 for {