From 8e42696cf481cc61ba52af8221cff0b0c38106ea Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Fri, 27 Oct 2023 15:04:09 +0200 Subject: [PATCH] Rework MPS limit normalization With this change we always specify limits in terms of UUIDs when passing these to the MPS control daemon. We also check for valid indices. Signed-off-by: Evan Lezar --- .../resource/gpu/nas/v1alpha1/sharing.go | 101 ++++++++++++++---- .../resource/gpu/nas/v1alpha1/sharing_test.go | 93 ++++++++++++++-- 2 files changed, 164 insertions(+), 30 deletions(-) diff --git a/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing.go b/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing.go index f4b75ee1..e052f411 100644 --- a/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing.go +++ b/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing.go @@ -17,6 +17,7 @@ package v1alpha1 import ( + "errors" "fmt" "strconv" @@ -185,37 +186,95 @@ func (c TimeSliceDuration) Int() int { return -1 } -// TODO: Always return a map of UUID -> limit +// ErrInvalidDeviceSelector indicates that a device index or UUID was invalid. +var ErrInvalidDeviceSelector error = errors.New("invalid device") + +// ErrInvalidLimit indicates that a limit was invalid. +var ErrInvalidLimit error = errors.New("invalid limit") + // Normalize converts the specified per-device pinned memory limits to limits for the devices that are to be allocated. // If provided, the defaultPinnedDeviceMemoryLimit is applied to each device before being overridden by specific values. func (m MpsPerDevicePinnedMemoryLimit) Normalize(uuids []string, defaultPinnedDeviceMemoryLimit *resource.Quantity) (map[string]string, error) { - limits := make(map[string]string) - - // We set the defaults for all expected devices. - if v := defaultPinnedDeviceMemoryLimit; v != nil { - value := v.Value() / 1024 / 1024 - if value == 0 { - return nil, fmt.Errorf("default value set too low: %v", v) - } - for i := range uuids { - limits[fmt.Sprintf("%d", i)] = fmt.Sprintf("%vM", value) - } + limits, err := (*limit)(defaultPinnedDeviceMemoryLimit).get(uuids) + if err != nil { + return nil, err } + devices := newUUIDSet(uuids) for k, v := range m { - // TODO: This has to be an integer or a UUID - // TODO: Check that k is valid for the list of UUIDs. e.g. can't be greater than the length - _, err := strconv.Atoi(k) + id, err := devices.Normalize(k) if err != nil { - return nil, fmt.Errorf("unable to parse key as an integer: %v", k) + return nil, err } - - value := v.Value() / 1024 / 1024 - if value == 0 { - return nil, fmt.Errorf("value set too low: %v: %v", k, v) + megabyte, valid := (limit)(v).Megabyte() + if !valid { + return nil, fmt.Errorf("%w: value set too low: %v: %v", ErrInvalidLimit, k, v) } + limits[id] = megabyte + } + return limits, nil +} + +type limit resource.Quantity + +func (d *limit) get(uuids []string) (map[string]string, error) { + limits := make(map[string]string) + if d == nil || len(uuids) == 0 { + return limits, nil + } - limits[k] = fmt.Sprintf("%vM", value) + megabyte, valid := d.Megabyte() + if !valid { + return nil, fmt.Errorf("%w: default value set too low: %v", ErrInvalidLimit, d) } + for _, uuid := range uuids { + limits[uuid] = megabyte + } + return limits, nil } + +func (d limit) Value() int64 { + return (*resource.Quantity)(&d).Value() +} + +func (d limit) Megabyte() (string, bool) { + v := d.Value() / 1024 / 1024 + return fmt.Sprintf("%vM", v), v > 0 +} + +type uuidSet struct { + uuids []string + lookup map[string]bool +} + +// newUUIDSet creates a set of UUIDs for managing pinned memory for requested devices. +func newUUIDSet(uuids []string) *uuidSet { + lookup := make(map[string]bool) + for _, uuid := range uuids { + lookup[uuid] = true + } + + return &uuidSet{ + uuids: uuids, + lookup: lookup, + } +} + +func (s *uuidSet) Normalize(key string) (string, error) { + // Check whether key is a UUID + if _, ok := s.lookup[key]; ok { + return key, nil + } + + index, err := strconv.Atoi(key) + if err != nil { + return "", fmt.Errorf("%w: unable to parse key as an integer: %v", ErrInvalidDeviceSelector, key) + } + + if index >= 0 && index < len(s.uuids) { + return s.uuids[index], nil + } + + return "", fmt.Errorf("%w: invalid device index: %v", ErrInvalidDeviceSelector, index) +} diff --git a/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing_test.go b/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing_test.go index 51a31c2c..5722f172 100644 --- a/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing_test.go +++ b/api/nvidia.com/resource/gpu/nas/v1alpha1/sharing_test.go @@ -35,13 +35,15 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) { expectedLimits map[string]string }{ { - description: "no uuids, no default", + description: "empty input", + expectedLimits: map[string]string{}, + }, + { + description: "no uuids, invalid device index", perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{ "0": resource.MustParse("1Gi"), }, - expectedLimits: map[string]string{ - "0": "1024M", - }, + expectedError: v1alpha1.ErrInvalidDeviceSelector, }, { description: "no uuids, default is overridden", @@ -49,17 +51,29 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) { perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{ "0": resource.MustParse("1Gi"), }, - expectedLimits: map[string]string{ - "0": "1024M", - }, + expectedError: v1alpha1.ErrInvalidDeviceSelector, }, { description: "uuids, default is set", uuids: []string{"UUID0"}, memoryLimit: ptr(resource.MustParse("2Gi")), expectedLimits: map[string]string{ - "0": "2048M", + "UUID0": "2048M", + }, + }, + { + description: "uuids, default is too low", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("1M")), + expectedError: v1alpha1.ErrInvalidLimit, + }, + { + description: "uuids, override is too low", + uuids: []string{"UUID0"}, + perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{ + "UUID0": resource.MustParse("1M"), }, + expectedError: v1alpha1.ErrInvalidLimit, }, { description: "uuids, default is overridden", @@ -69,7 +83,68 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) { "0": resource.MustParse("1Gi"), }, expectedLimits: map[string]string{ - "0": "1024M", + "UUID0": "1024M", + }, + }, + { + description: "uuids, default is overridden by uuid", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("2Gi")), + perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{ + "UUID0": resource.MustParse("1Gi"), + }, + expectedLimits: map[string]string{ + "UUID0": "1024M", + }, + }, + { + description: "uuids, default is overridden, invalid UUID", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("2Gi")), + perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{ + "UUID1": resource.MustParse("1Gi"), + }, + expectedError: v1alpha1.ErrInvalidDeviceSelector, + }, + { + description: "uuids, default is overridden, invalid index", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("2Gi")), + perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{ + "1": resource.MustParse("1Gi"), + }, + expectedError: v1alpha1.ErrInvalidDeviceSelector, + }, + { + description: "unit conversion Mi to M", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("10Mi")), + expectedLimits: map[string]string{ + "UUID0": "10M", + }, + }, + { + description: "unit conversion Gi to M", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("1Gi")), + expectedLimits: map[string]string{ + "UUID0": "1024M", + }, + }, + { + description: "unit conversion M to M", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("10M")), + expectedLimits: map[string]string{ + "UUID0": "9M", + }, + }, + { + description: "unit conversion G to M", + uuids: []string{"UUID0"}, + memoryLimit: ptr(resource.MustParse("1G")), + expectedLimits: map[string]string{ + "UUID0": "953M", }, }, }