Skip to content

Commit

Permalink
rootfs: consolidate mountpoint creation logic
Browse files Browse the repository at this point in the history
The logic for how we create mountpoints is spread over each mountpoint
preparation function, when in reality the behaviour is pretty uniform
with only a handful of exceptions. So just move it all to one function
that is easier to understand.

Signed-off-by: Aleksa Sarai <[email protected]>
  • Loading branch information
cyphar committed Jul 25, 2024
1 parent 3778ae6 commit 1410a69
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 91 deletions.
35 changes: 7 additions & 28 deletions libcontainer/criu_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,17 +523,7 @@ func (c *Container) restoreNetwork(req *criurpc.CriuReq, criuOpts *CriuOpts) {
// restore using CRIU. This function is inspired from the code in
// rootfs_linux.go.
func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error {
me := mountEntry{Mount: m}
dest, err := securejoin.SecureJoin(c.config.Rootfs, m.Destination)
if err != nil {
return err
}
// TODO: pass srcFD? Not sure if criu is impacted by issue #2484.
if err := checkProcMount(c.config.Rootfs, dest, me); err != nil {
return err
}
switch m.Device {
case "cgroup":
if m.Device == "cgroup" {
// No mount point(s) need to be created:
//
// * for v1, mount points are saved by CRIU because
Expand All @@ -542,23 +532,12 @@ func (c *Container) makeCriuRestoreMountpoints(m *configs.Mount) error {
// * for v2, /sys/fs/cgroup is a real mount, but
// the mountpoint appears as soon as /sys is mounted
return nil
case "bind":
// For bind-mounts (unlike other filesystem types), we need to check if
// the source exists.
fi, _, err := me.srcStat()
if err != nil {
// error out if the source of a bind mount does not exist as we
// will be unable to bind anything to it.
return err
}
if err := createIfNotExists(dest, fi.IsDir()); err != nil {
return err
}
default:
// for all other filesystems just create the mountpoints
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
}
// TODO: pass srcFD? Not sure if criu is impacted by issue #2484.
me := mountEntry{Mount: m}
// For all other filesystems, just make the target.
if _, err := createMountpoint(c.config.Rootfs, me); err != nil {
return fmt.Errorf("create criu restore mountpoint for %s mount: %w", me.Destination, err)
}
return nil
}
Expand Down
135 changes: 72 additions & 63 deletions libcontainer/rootfs_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,11 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {

for _, b := range binds {
if c.cgroupns {
// We just created the tmpfs, and so we can just use filepath.Join
// here (not to mention we want to make sure we create the path
// inside the tmpfs, so we don't want to resolve symlinks).
subsystemPath := filepath.Join(c.root, b.Destination)
subsystemName := filepath.Base(b.Destination)
if err := os.MkdirAll(subsystemPath, 0o755); err != nil {
return err
}
Expand All @@ -319,7 +323,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {
}
var (
source = "cgroup"
data = filepath.Base(subsystemPath)
data = subsystemName
)
if data == "systemd" {
data = cgroups.CgroupNamePrefix + data
Expand Down Expand Up @@ -349,14 +353,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error {
}

func mountCgroupV2(m *configs.Mount, c *mountConfig) error {
dest, err := securejoin.SecureJoin(c.root, m.Destination)
if err != nil {
return err
}
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
err = utils.WithProcfd(c.root, m.Destination, func(dstFd string) error {
err := utils.WithProcfd(c.root, m.Destination, func(dstFd string) error {
return mountViaFds(m.Source, nil, m.Destination, dstFd, "cgroup2", uintptr(m.Flags), m.Data)
})
if err == nil || !(errors.Is(err, unix.EPERM) || errors.Is(err, unix.EBUSY)) {
Expand Down Expand Up @@ -482,6 +479,65 @@ func statfsToMountFlags(st unix.Statfs_t) int {
return flags
}

var errRootfsToFile = errors.New("config tries to change rootfs to file")

func createMountpoint(rootfs string, m mountEntry) (string, error) {
dest, err := securejoin.SecureJoin(rootfs, m.Destination)
if err != nil {
return "", err
}
if err := checkProcMount(rootfs, dest, m); err != nil {
return "", fmt.Errorf("check proc-safety of %s mount: %w", m.Destination, err)
}

switch m.Device {
case "bind":
fi, _, err := m.srcStat()
if err != nil {
// Error out if the source of a bind mount does not exist as we
// will be unable to bind anything to it.
return "", err
}
// If the original source is not a directory, make the target a file.
if !fi.IsDir() {
// Make sure we aren't tricked into trying to make the root a file.
if rootfs == dest {
return "", fmt.Errorf("%w: file bind mount over rootfs", errRootfsToFile)
}
// Make the parent directory.
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
return "", fmt.Errorf("make parent dir of file bind-mount: %w", err)
}
// Make the target file.
f, err := os.OpenFile(dest, os.O_CREATE, 0o755)
if err != nil {
return "", fmt.Errorf("create target of file bind-mount: %w", err)
}
_ = f.Close()
// Nothing left to do.
return dest, nil
}

case "tmpfs":
// If the original target exists, copy the mode for the tmpfs mount.
if stat, err := os.Stat(dest); err == nil {
dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode()))
if m.Data != "" {
dt = dt + "," + m.Data
}
m.Data = dt

// Nothing left to do.
return dest, nil
}
}

if err := os.MkdirAll(dest, 0o755); err != nil {
return "", err
}
return dest, nil
}

func mountToRootfs(c *mountConfig, m mountEntry) error {
rootfs := c.root

Expand All @@ -495,7 +551,7 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
// TODO: This won't be necessary once we switch to libpathrs and we can
// stop all of these symlink-exchange attacks.
dest := filepath.Clean(m.Destination)
if !strings.HasPrefix(dest, rootfs) {
if !utils.IsLexicallyInRoot(rootfs, dest) {
// Do not use securejoin as it resolves symlinks.
dest = filepath.Join(rootfs, dest)
}
Expand All @@ -516,37 +572,19 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
return mountPropagate(m, rootfs, "")
}

mountLabel := c.label
dest, err := securejoin.SecureJoin(rootfs, m.Destination)
dest, err := createMountpoint(rootfs, m)
if err != nil {
return err
}
if err := checkProcMount(rootfs, dest, m); err != nil {
return err
return fmt.Errorf("create mountpoint for %s mount: %w", m.Destination, err)
}
mountLabel := c.label

switch m.Device {
case "mqueue":
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
if err := mountPropagate(m, rootfs, ""); err != nil {
return err
}
return label.SetFileLabel(dest, mountLabel)
case "tmpfs":
if stat, err := os.Stat(dest); err != nil {
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
} else {
dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode()))
if m.Data != "" {
dt = dt + "," + m.Data
}
m.Data = dt
}

if m.Extensions&configs.EXT_COPYUP == configs.EXT_COPYUP {
err = doTmpfsCopyUp(m, rootfs, mountLabel)
} else {
Expand All @@ -555,15 +593,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {

return err
case "bind":
fi, _, err := m.srcStat()
if err != nil {
// error out if the source of a bind mount does not exist as we will be
// unable to bind anything to it.
return err
}
if err := createIfNotExists(dest, fi.IsDir()); err != nil {
return err
}
// open_tree()-related shenanigans are all handled in mountViaFds.
if err := mountPropagate(m, rootfs, mountLabel); err != nil {
return err
Expand Down Expand Up @@ -679,9 +708,6 @@ func mountToRootfs(c *mountConfig, m mountEntry) error {
}
return mountCgroupV1(m.Mount, c)
default:
if err := os.MkdirAll(dest, 0o755); err != nil {
return err
}
return mountPropagate(m, rootfs, mountLabel)
}
}
Expand Down Expand Up @@ -899,6 +925,9 @@ func createDeviceNode(rootfs string, node *devices.Device, bind bool) error {
if err != nil {
return err
}
if dest == rootfs {
return fmt.Errorf("%w: mknod over rootfs", errRootfsToFile)
}
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
return err
}
Expand Down Expand Up @@ -1169,26 +1198,6 @@ func chroot() error {
return nil
}

// createIfNotExists creates a file or a directory only if it does not already exist.
func createIfNotExists(path string, isDir bool) error {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
if isDir {
return os.MkdirAll(path, 0o755)
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE, 0o755)
if err != nil {
return err
}
_ = f.Close()
}
}
return nil
}

// readonlyPath will make a path read only.
func readonlyPath(path string) error {
if err := mount(path, path, "", unix.MS_BIND|unix.MS_REC, ""); err != nil {
Expand Down
15 changes: 15 additions & 0 deletions libcontainer/utils/utils_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
_ "unsafe" // for go:linkname

Expand Down Expand Up @@ -260,3 +261,17 @@ func ProcThreadSelf(subpath string) (string, ProcThreadSelfCloser) {
func ProcThreadSelfFd(fd uintptr) (string, ProcThreadSelfCloser) {
return ProcThreadSelf("fd/" + strconv.FormatUint(uint64(fd), 10))
}

// IsLexicallyInRoot is shorthand for strings.HasPrefix(path+"/", root+"/"),
// but properly handling the case where path or root are "/".
//
// NOTE: The return value only make sense if the path doesn't contain "..".
func IsLexicallyInRoot(root, path string) bool {
if root != "/" {
root += "/"
}
if path != "/" {
path += "/"
}
return strings.HasPrefix(path, root)
}

0 comments on commit 1410a69

Please sign in to comment.