diff --git a/libcontainer/criu_linux.go b/libcontainer/criu_linux.go index fda16bead4c..4c6ae71465f 100644 --- a/libcontainer/criu_linux.go +++ b/libcontainer/criu_linux.go @@ -523,17 +523,7 @@ func (c *Container) restoreNetwork(req *criurpc.CriuReq, criuOpts *CriuOpts) { // restore using CRIU. This function is inspired from the code in // rootfs_linux.go. func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error { - me := mountEntry{Mount: m} - dest, err := securejoin.SecureJoin(c.config.Rootfs, m.Destination) - if err != nil { - return err - } - // TODO: pass srcFD? Not sure if criu is impacted by issue #2484. - if err := checkProcMount(c.config.Rootfs, dest, me); err != nil { - return err - } - switch m.Device { - case "cgroup": + if m.Device == "cgroup" { // No mount point(s) need to be created: // // * for v1, mount points are saved by CRIU because @@ -542,23 +532,12 @@ func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error { // * for v2, /sys/fs/cgroup is a real mount, but // the mountpoint appears as soon as /sys is mounted return nil - case "bind": - // For bind-mounts (unlike other filesystem types), we need to check if - // the source exists. - fi, _, err := me.srcStat() - if err != nil { - // error out if the source of a bind mount does not exist as we - // will be unable to bind anything to it. - return err - } - if err := createIfNotExists(dest, fi.IsDir()); err != nil { - return err - } - default: - // for all other filesystems just create the mountpoints - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } + } + // TODO: pass srcFD? Not sure if criu is impacted by issue #2484. + me := mountEntry{Mount: m} + // For all other filesystems, just make the target. + if _, err := createMountpoint(c.config.Rootfs, me); err != nil { + return fmt.Errorf("create criu restore mountpoint for %s mount: %w", me.Destination, err) } return nil } diff --git a/libcontainer/rootfs_linux.go b/libcontainer/rootfs_linux.go index 348d30fefe8..16878274cae 100644 --- a/libcontainer/rootfs_linux.go +++ b/libcontainer/rootfs_linux.go @@ -308,7 +308,11 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error { for _, b := range binds { if c.cgroupns { + // We just created the tmpfs, and so we can just use filepath.Join + // here (not to mention we want to make sure we create the path + // inside the tmpfs, so we don't want to resolve symlinks). subsystemPath := filepath.Join(c.root, b.Destination) + subsystemName := filepath.Base(b.Destination) if err := os.MkdirAll(subsystemPath, 0o755); err != nil { return err } @@ -319,7 +323,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error { } var ( source = "cgroup" - data = filepath.Base(subsystemPath) + data = subsystemName ) if data == "systemd" { data = cgroups.CgroupNamePrefix + data @@ -349,14 +353,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error { } func mountCgroupV2(m *configs.Mount, c *mountConfig) error { - dest, err := securejoin.SecureJoin(c.root, m.Destination) - if err != nil { - return err - } - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } - err = utils.WithProcfd(c.root, m.Destination, func(dstFd string) error { + err := utils.WithProcfd(c.root, m.Destination, func(dstFd string) error { return mountViaFds(m.Source, nil, m.Destination, dstFd, "cgroup2", uintptr(m.Flags), m.Data) }) if err == nil || !(errors.Is(err, unix.EPERM) || errors.Is(err, unix.EBUSY)) { @@ -482,6 +479,65 @@ func statfsToMountFlags(st unix.Statfs_t) int { return flags } +var errRootfsToFile = errors.New("config tries to change rootfs to file") + +func createMountpoint(rootfs string, m mountEntry) (string, error) { + dest, err := securejoin.SecureJoin(rootfs, m.Destination) + if err != nil { + return "", err + } + if err := checkProcMount(rootfs, dest, m); err != nil { + return "", fmt.Errorf("check proc-safety of %s mount: %w", m.Destination, err) + } + + switch m.Device { + case "bind": + fi, _, err := m.srcStat() + if err != nil { + // Error out if the source of a bind mount does not exist as we + // will be unable to bind anything to it. + return "", err + } + // If the original source is not a directory, make the target a file. + if !fi.IsDir() { + // Make sure we aren't tricked into trying to make the root a file. + if rootfs == dest { + return "", fmt.Errorf("%w: file bind mount over rootfs", errRootfsToFile) + } + // Make the parent directory. + if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { + return "", fmt.Errorf("make parent dir of file bind-mount: %w", err) + } + // Make the target file. + f, err := os.OpenFile(dest, os.O_CREATE, 0o755) + if err != nil { + return "", fmt.Errorf("create target of file bind-mount: %w", err) + } + _ = f.Close() + // Nothing left to do. + return dest, nil + } + + case "tmpfs": + // If the original target exists, copy the mode for the tmpfs mount. + if stat, err := os.Stat(dest); err == nil { + dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode())) + if m.Data != "" { + dt = dt + "," + m.Data + } + m.Data = dt + + // Nothing left to do. + return dest, nil + } + } + + if err := os.MkdirAll(dest, 0o755); err != nil { + return "", err + } + return dest, nil +} + func mountToRootfs(c *mountConfig, m mountEntry) error { rootfs := c.root @@ -495,7 +551,7 @@ func mountToRootfs(c *mountConfig, m mountEntry) error { // TODO: This won't be necessary once we switch to libpathrs and we can // stop all of these symlink-exchange attacks. dest := filepath.Clean(m.Destination) - if !strings.HasPrefix(dest, rootfs) { + if !utils.IsLexicallyInRoot(rootfs, dest) { // Do not use securejoin as it resolves symlinks. dest = filepath.Join(rootfs, dest) } @@ -516,37 +572,19 @@ func mountToRootfs(c *mountConfig, m mountEntry) error { return mountPropagate(m, rootfs, "") } - mountLabel := c.label - dest, err := securejoin.SecureJoin(rootfs, m.Destination) + dest, err := createMountpoint(rootfs, m) if err != nil { - return err - } - if err := checkProcMount(rootfs, dest, m); err != nil { - return err + return fmt.Errorf("create mountpoint for %s mount: %w", m.Destination, err) } + mountLabel := c.label switch m.Device { case "mqueue": - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } if err := mountPropagate(m, rootfs, ""); err != nil { return err } return label.SetFileLabel(dest, mountLabel) case "tmpfs": - if stat, err := os.Stat(dest); err != nil { - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } - } else { - dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode())) - if m.Data != "" { - dt = dt + "," + m.Data - } - m.Data = dt - } - if m.Extensions&configs.EXT_COPYUP == configs.EXT_COPYUP { err = doTmpfsCopyUp(m, rootfs, mountLabel) } else { @@ -555,15 +593,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error { return err case "bind": - fi, _, err := m.srcStat() - if err != nil { - // error out if the source of a bind mount does not exist as we will be - // unable to bind anything to it. - return err - } - if err := createIfNotExists(dest, fi.IsDir()); err != nil { - return err - } // open_tree()-related shenanigans are all handled in mountViaFds. if err := mountPropagate(m, rootfs, mountLabel); err != nil { return err @@ -679,9 +708,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error { } return mountCgroupV1(m.Mount, c) default: - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } return mountPropagate(m, rootfs, mountLabel) } } @@ -899,6 +925,9 @@ func createDeviceNode(rootfs string, node *devices.Device, bind bool) error { if err != nil { return err } + if dest == rootfs { + return fmt.Errorf("%w: mknod over rootfs", errRootfsToFile) + } if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { return err } @@ -1169,26 +1198,6 @@ func chroot() error { return nil } -// createIfNotExists creates a file or a directory only if it does not already exist. -func createIfNotExists(path string, isDir bool) error { - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - if isDir { - return os.MkdirAll(path, 0o755) - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - f, err := os.OpenFile(path, os.O_CREATE, 0o755) - if err != nil { - return err - } - _ = f.Close() - } - } - return nil -} - // readonlyPath will make a path read only. func readonlyPath(path string) error { if err := mount(path, path, "", unix.MS_BIND|unix.MS_REC, ""); err != nil { diff --git a/libcontainer/utils/utils_unix.go b/libcontainer/utils/utils_unix.go index 7b0acfa833a..6bf9102f414 100644 --- a/libcontainer/utils/utils_unix.go +++ b/libcontainer/utils/utils_unix.go @@ -9,6 +9,7 @@ import ( "path/filepath" "runtime" "strconv" + "strings" "sync" _ "unsafe" // for go:linkname @@ -260,3 +261,17 @@ func ProcThreadSelf(subpath string) (string, ProcThreadSelfCloser) { func ProcThreadSelfFd(fd uintptr) (string, ProcThreadSelfCloser) { return ProcThreadSelf("fd/" + strconv.FormatUint(uint64(fd), 10)) } + +// IsLexicallyInRoot is shorthand for strings.HasPrefix(path+"/", root+"/"), +// but properly handling the case where path or root are "/". +// +// NOTE: The return value only make sense if the path doesn't contain "..". +func IsLexicallyInRoot(root, path string) bool { + if root != "/" { + root += "/" + } + if path != "/" { + path += "/" + } + return strings.HasPrefix(path, root) +}