Skip to content

Commit

Permalink
Rework MPS limit normalization
Browse files Browse the repository at this point in the history
With this change we always specify limits in terms of UUIDs
when passing these to the MPS control daemon. We also check for
valid indices.

Signed-off-by: Evan Lezar <[email protected]>
  • Loading branch information
elezar committed Feb 29, 2024
1 parent 11d749c commit 8e42696
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 30 deletions.
101 changes: 80 additions & 21 deletions api/nvidia.com/resource/gpu/nas/v1alpha1/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package v1alpha1

import (
"errors"
"fmt"
"strconv"

Expand Down Expand Up @@ -185,37 +186,95 @@ func (c TimeSliceDuration) Int() int {
return -1
}

// TODO: Always return a map of UUID -> limit
// ErrInvalidDeviceSelector indicates that a device index or UUID was invalid.
var ErrInvalidDeviceSelector error = errors.New("invalid device")

// ErrInvalidLimit indicates that a limit was invalid.
var ErrInvalidLimit error = errors.New("invalid limit")

// Normalize converts the specified per-device pinned memory limits to limits for the devices that are to be allocated.
// If provided, the defaultPinnedDeviceMemoryLimit is applied to each device before being overridden by specific values.
func (m MpsPerDevicePinnedMemoryLimit) Normalize(uuids []string, defaultPinnedDeviceMemoryLimit *resource.Quantity) (map[string]string, error) {
limits := make(map[string]string)

// We set the defaults for all expected devices.
if v := defaultPinnedDeviceMemoryLimit; v != nil {
value := v.Value() / 1024 / 1024
if value == 0 {
return nil, fmt.Errorf("default value set too low: %v", v)
}
for i := range uuids {
limits[fmt.Sprintf("%d", i)] = fmt.Sprintf("%vM", value)
}
limits, err := (*limit)(defaultPinnedDeviceMemoryLimit).get(uuids)
if err != nil {
return nil, err
}

devices := newUUIDSet(uuids)
for k, v := range m {
// TODO: This has to be an integer or a UUID
// TODO: Check that k is valid for the list of UUIDs. e.g. can't be greater than the length
_, err := strconv.Atoi(k)
id, err := devices.Normalize(k)
if err != nil {
return nil, fmt.Errorf("unable to parse key as an integer: %v", k)
return nil, err
}

value := v.Value() / 1024 / 1024
if value == 0 {
return nil, fmt.Errorf("value set too low: %v: %v", k, v)
megabyte, valid := (limit)(v).Megabyte()
if !valid {
return nil, fmt.Errorf("%w: value set too low: %v: %v", ErrInvalidLimit, k, v)
}
limits[id] = megabyte
}
return limits, nil
}

type limit resource.Quantity

func (d *limit) get(uuids []string) (map[string]string, error) {
limits := make(map[string]string)
if d == nil || len(uuids) == 0 {
return limits, nil
}

limits[k] = fmt.Sprintf("%vM", value)
megabyte, valid := d.Megabyte()
if !valid {
return nil, fmt.Errorf("%w: default value set too low: %v", ErrInvalidLimit, d)
}
for _, uuid := range uuids {
limits[uuid] = megabyte
}

return limits, nil
}

func (d limit) Value() int64 {
return (*resource.Quantity)(&d).Value()
}

func (d limit) Megabyte() (string, bool) {
v := d.Value() / 1024 / 1024
return fmt.Sprintf("%vM", v), v > 0
}

type uuidSet struct {
uuids []string
lookup map[string]bool
}

// newUUIDSet creates a set of UUIDs for managing pinned memory for requested devices.
func newUUIDSet(uuids []string) *uuidSet {
lookup := make(map[string]bool)
for _, uuid := range uuids {
lookup[uuid] = true
}

return &uuidSet{
uuids: uuids,
lookup: lookup,
}
}

func (s *uuidSet) Normalize(key string) (string, error) {
// Check whether key is a UUID
if _, ok := s.lookup[key]; ok {
return key, nil
}

index, err := strconv.Atoi(key)
if err != nil {
return "", fmt.Errorf("%w: unable to parse key as an integer: %v", ErrInvalidDeviceSelector, key)
}

if index >= 0 && index < len(s.uuids) {
return s.uuids[index], nil
}

return "", fmt.Errorf("%w: invalid device index: %v", ErrInvalidDeviceSelector, index)
}
93 changes: 84 additions & 9 deletions api/nvidia.com/resource/gpu/nas/v1alpha1/sharing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,45 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) {
expectedLimits map[string]string
}{
{
description: "no uuids, no default",
description: "empty input",
expectedLimits: map[string]string{},
},
{
description: "no uuids, invalid device index",
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"0": "1024M",
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "no uuids, default is overridden",
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"0": "1024M",
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "uuids, default is set",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
expectedLimits: map[string]string{
"0": "2048M",
"UUID0": "2048M",
},
},
{
description: "uuids, default is too low",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("1M")),
expectedError: v1alpha1.ErrInvalidLimit,
},
{
description: "uuids, override is too low",
uuids: []string{"UUID0"},
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"UUID0": resource.MustParse("1M"),
},
expectedError: v1alpha1.ErrInvalidLimit,
},
{
description: "uuids, default is overridden",
Expand All @@ -69,7 +83,68 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) {
"0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"0": "1024M",
"UUID0": "1024M",
},
},
{
description: "uuids, default is overridden by uuid",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"UUID0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"UUID0": "1024M",
},
},
{
description: "uuids, default is overridden, invalid UUID",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"UUID1": resource.MustParse("1Gi"),
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "uuids, default is overridden, invalid index",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"1": resource.MustParse("1Gi"),
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "unit conversion Mi to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("10Mi")),
expectedLimits: map[string]string{
"UUID0": "10M",
},
},
{
description: "unit conversion Gi to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("1Gi")),
expectedLimits: map[string]string{
"UUID0": "1024M",
},
},
{
description: "unit conversion M to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("10M")),
expectedLimits: map[string]string{
"UUID0": "9M",
},
},
{
description: "unit conversion G to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("1G")),
expectedLimits: map[string]string{
"UUID0": "953M",
},
},
}
Expand Down

0 comments on commit 8e42696

Please sign in to comment.