diff --git a/libcontainer/cgroups/fs2/memory.go b/libcontainer/cgroups/fs2/memory.go index f1db078613e..a0c619b5c97 100644 --- a/libcontainer/cgroups/fs2/memory.go +++ b/libcontainer/cgroups/fs2/memory.go @@ -32,11 +32,16 @@ func numToStr(value int64) (ret string) { } func setMemory(dirPath string, cgroup *configs.Cgroup) error { - if val := numToStr(cgroup.Resources.MemorySwap); val != "" { + swap, err := cgroups.ConvertMemorySwapToCgroupV2Value(cgroup.Resources.MemorySwap, cgroup.Resources.Memory) + if err != nil { + return err + } + if val := numToStr(swap); val != "" { if err := fscommon.WriteFile(dirPath, "memory.swap.max", val); err != nil { return err } } + if val := numToStr(cgroup.Resources.Memory); val != "" { if err := fscommon.WriteFile(dirPath, "memory.max", val); err != nil { return err diff --git a/libcontainer/cgroups/systemd/unified_hierarchy.go b/libcontainer/cgroups/systemd/unified_hierarchy.go index 7fab68fc450..2dffc34e076 100644 --- a/libcontainer/cgroups/systemd/unified_hierarchy.go +++ b/libcontainer/cgroups/systemd/unified_hierarchy.go @@ -84,6 +84,15 @@ func (m *UnifiedManager) Apply(pid int) error { newProp("MemoryMax", uint64(c.Resources.Memory))) } + swap, err := cgroups.ConvertMemorySwapToCgroupV2Value(c.Resources.MemorySwap, c.Resources.Memory) + if err != nil { + return err + } + if swap > 0 { + properties = append(properties, + newProp("MemorySwapMax", uint64(swap))) + } + if c.Resources.CpuWeight != 0 { properties = append(properties, newProp("CPUWeight", c.Resources.CpuWeight)) diff --git a/libcontainer/cgroups/utils.go b/libcontainer/cgroups/utils.go index 51b743e2ce3..8002c052999 100644 --- a/libcontainer/cgroups/utils.go +++ b/libcontainer/cgroups/utils.go @@ -623,3 +623,25 @@ func ConvertCPUQuotaCPUPeriodToCgroupV2Value(quota int64, period uint64) string } return fmt.Sprintf("%d %d", quota, period) } + +// ConvertMemorySwapToCgroupV2Value converts MemorySwap value from OCI spec +// for use by cgroup v2 drivers. A conversion is needed since Resources.MemorySwap +// is defined as memory+swap combined, while in cgroup v2 swap is a separate value. +func ConvertMemorySwapToCgroupV2Value(memorySwap, memory int64) (int64, error) { + if memorySwap == -1 || memorySwap == 0 { + // -1 is "max", 0 is "unset", so treat as is + return memorySwap, nil + } + // sanity checks + if memory == 0 || memory == -1 { + return 0, errors.New("unable to set swap limit without memory limit") + } + if memory < 0 { + return 0, fmt.Errorf("invalid memory value: %d", memory) + } + if memorySwap < memory { + return 0, errors.New("memory+swap limit should be > memory limit") + } + + return memorySwap - memory, nil +} diff --git a/libcontainer/cgroups/utils_test.go b/libcontainer/cgroups/utils_test.go index 14a76bf725b..adf266858bf 100644 --- a/libcontainer/cgroups/utils_test.go +++ b/libcontainer/cgroups/utils_test.go @@ -530,3 +530,84 @@ func TestConvertCPUQuotaCPUPeriodToCgroupV2Value(t *testing.T) { } } } + +func TestConvertMemorySwapToCgroupV2Value(t *testing.T) { + cases := []struct { + memswap, memory int64 + expected int64 + expErr bool + }{ + { + memswap: 0, + memory: 0, + expected: 0, + }, + { + memswap: -1, + memory: 0, + expected: -1, + }, + { + memswap: -1, + memory: -1, + expected: -1, + }, + { + memswap: -2, + memory: 0, + expErr: true, + }, + { + memswap: -1, + memory: 1000, + expected: -1, + }, + { + memswap: 1000, + memory: 1000, + expected: 0, + }, + { + memswap: 500, + memory: 200, + expected: 300, + }, + { + memswap: 300, + memory: 400, + expErr: true, + }, + { + memswap: 300, + memory: 0, + expErr: true, + }, + { + memswap: 300, + memory: -300, + expErr: true, + }, + { + memswap: 300, + memory: -1, + expErr: true, + }, + } + + for _, c := range cases { + swap, err := ConvertMemorySwapToCgroupV2Value(c.memswap, c.memory) + if c.expErr { + if err == nil { + t.Errorf("memswap: %d, memory %d, expected error, got %d, nil", c.memswap, c.memory, swap) + } + // no more checks + continue + } + if err != nil { + t.Errorf("memswap: %d, memory %d, expected success, got error %s", c.memswap, c.memory, err) + } + if swap != c.expected { + t.Errorf("memswap: %d, memory %d, expected %d, got %d", c.memswap, c.memory, c.expected, swap) + } + } +}