Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Ramachandra Sekar <[email protected]>
  • Loading branch information
Varun Ramachandra Sekar committed Dec 9, 2024
1 parent 5ce0763 commit f152735
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 98 deletions.
2 changes: 1 addition & 1 deletion api/nvidia.com/resource/gpu/v1alpha1/driverconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (c *GpuDriverConfig) Validate() error {
case VfioPciDriver:
break
default:
return fmt.Errorf("invalid driver specified in gpu driver configuration")
return fmt.Errorf("invalid driver '%s' specified in gpu driver configuration", c.Driver)
}
return nil
}
33 changes: 18 additions & 15 deletions api/nvidia.com/resource/gpu/v1alpha1/gpuconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ func DefaultGpuConfig() *GpuConfig {
func (c *GpuConfig) Normalize() error {
if c.DriverConfig == nil {
c.DriverConfig = DefaultGpuDriverConfig()
} else {
if err := c.DriverConfig.Normalize(); err != nil {
return err
}
}
// If driver is not Nvidia, don't proceed with normalizing sharing configuration.
if c.DriverConfig.Driver != NvidiaDriver {

if err := c.DriverConfig.Normalize(); err != nil {
return err
}

// If sharing is not supported, don't proceed with normalizing its configuration.
if !c.SupportsSharing() {
return nil
}

Expand All @@ -84,21 +85,23 @@ func (c *GpuConfig) Normalize() error {

// Validate ensures that GpuConfig has a valid set of values.
func (c *GpuConfig) Validate() error {
if c.DriverConfig.Driver == NvidiaDriver {
if err := c.DriverConfig.Validate(); err != nil {
return err
}

if c.SupportsSharing() {
if c.Sharing == nil {
return fmt.Errorf("no sharing strategy set")
}
if err := c.Sharing.Validate(); err != nil {
return err
}
} else {
if c.Sharing != nil {
return fmt.Errorf("sharing strategy cannot be provided while using non-nvidia driver")
}
}
if err := c.DriverConfig.Validate(); err != nil {
return err
} else if c.Sharing != nil {
return fmt.Errorf("sharing strategy cannot be provided while using non-nvidia driver")
}

return nil
}

func (c *GpuConfig) SupportsSharing() bool {
return c.DriverConfig.Driver == NvidiaDriver
}
1 change: 1 addition & 0 deletions cmd/nvidia-dra-plugin/allocatable.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,6 @@ func (d AllocatableDevices) PciAddresses() []string {
pciAddresses = append(pciAddresses, device.Gpu.PciAddress)
}
}
slices.Sort(pciAddresses)
return pciAddresses
}
107 changes: 51 additions & 56 deletions cmd/nvidia-dra-plugin/device_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type OpaqueDeviceConfig struct {

type DeviceConfigState struct {
MpsControlDaemonID string `json:"mpsControlDaemonID"`
GpuConfig *configapi.GpuConfig `json:"deviceConfig,omitempty"`
GpuConfig *configapi.GpuConfig `json:"gpuConfig,omitempty"`
containerEdits *cdiapi.ContainerEdits
}

Expand Down Expand Up @@ -112,7 +112,9 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
}

// Initialize the vfio-pci driver manager.
vfioPciManager.Init()
if err := vfioPciManager.Init(); err != nil {
return nil, fmt.Errorf("unable to initialize vfio-pci manager: %v", err)
}

checkpoints, err := state.checkpointManager.ListCheckpoints()
if err != nil {
Expand Down Expand Up @@ -357,12 +359,21 @@ func (s *DeviceState) prepareDevices(ctx context.Context, claim *resourceapi.Res

func (s *DeviceState) unprepareDevices(ctx context.Context, claimUID string, devices PreparedDevices) error {
for _, group := range devices {
var err error
if group.ConfigState.GpuConfig != nil {
err = s.unprepareGpus(ctx, group.ConfigState.GpuConfig, group.Devices.Gpus())
err := s.unprepareGpus(ctx, group.ConfigState.GpuConfig, group.Devices.Gpus())
if err != nil {
return err
}
}
if err != nil {
return err
// Stop any MPS control daemons started for each group of prepared devices.
mpsControlDaemon := s.mpsManager.NewMpsControlDaemon(claimUID, group)
if err := mpsControlDaemon.Stop(ctx); err != nil {
return fmt.Errorf("error stopping MPS control daemon: %w", err)
}
// Go back to default time-slicing for all full GPUs.
tsc := configapi.DefaultGpuConfig().Sharing.TimeSlicingConfig
if err := s.tsManager.SetTimeSlice(devices, tsc); err != nil {
return fmt.Errorf("error setting timeslice for devices: %w", err)
}
}
return nil
Expand All @@ -375,52 +386,40 @@ func (s *DeviceState) unprepareGpus(ctx context.Context, config *configapi.GpuCo
}
}
}
// Go back to default time-slicing for all full GPUs.
tsc := configapi.DefaultGpuConfig().Sharing.TimeSlicingConfig
if err := s.tsManager.SetTimeSlice(devices, tsc); err != nil {
return fmt.Errorf("error setting timeslice for devices: %w", err)
}
return nil
}

func (s *DeviceState) applyConfig(ctx context.Context, config configapi.Interface, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult) (*DeviceConfigState, error) {
var err error
var configState DeviceConfigState

switch castConfig := config.(type) {
case *configapi.GpuConfig:
configState.GpuConfig = castConfig
err = s.applyGpuConfig(ctx, castConfig, claim, results, &configState)
return s.applyGpuConfig(ctx, castConfig, claim, results, &configState)
case *configapi.MigDeviceConfig:
err = s.applySharingConfig(ctx, castConfig.Sharing, claim, results, &configState)
return s.applySharingConfig(ctx, castConfig.Sharing, claim, results, &configState)
case *configapi.ImexChannelConfig:
err = s.applyImexChannelConfig(ctx, castConfig, claim, results, &configState)
return s.applyImexChannelConfig(ctx, castConfig, claim, results, &configState)
default:
err = fmt.Errorf("unknown config type: %T", castConfig)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("unknown config type: %T", castConfig)
}
return &configState, nil
}

func (s *DeviceState) applyGpuConfig(ctx context.Context, config *configapi.GpuConfig, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) error {
if config.Sharing != nil {
err := s.applySharingConfig(ctx, config.Sharing, claim, results, configState)
if err != nil {
return err
}
func (s *DeviceState) applyGpuConfig(ctx context.Context, config *configapi.GpuConfig, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) (*DeviceConfigState, error) {
var err error
configState, err = s.applyGpuDriverConfig(ctx, config.DriverConfig, results, configState)
if err != nil {
return nil, err
}
if config.DriverConfig != nil {
err := s.applyGpuDriverConfig(ctx, config.DriverConfig, results, configState)
if config.SupportsSharing() {
configState, err = s.applySharingConfig(ctx, config.Sharing, claim, results, configState)
if err != nil {
return err
return nil, err
}
}
return nil
return configState, nil
}

func (s *DeviceState) applySharingConfig(ctx context.Context, config configapi.Sharing, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) error {
func (s *DeviceState) applySharingConfig(ctx context.Context, config configapi.Sharing, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) (*DeviceConfigState, error) {
// Get the list of claim requests this config is being applied over.
var requests []string
for _, r := range results {
Expand All @@ -437,12 +436,12 @@ func (s *DeviceState) applySharingConfig(ctx context.Context, config configapi.S
if config.IsTimeSlicing() {
tsc, err := config.GetTimeSlicingConfig()
if err != nil {
return fmt.Errorf("error getting timeslice config for requests '%v' in claim '%v': %w", requests, claim.UID, err)
return nil, fmt.Errorf("error getting timeslice config for requests '%v' in claim '%v': %w", requests, claim.UID, err)
}
if tsc != nil {
err = s.tsManager.SetTimeSlice(allocatableDevices, tsc)
if err != nil {
return fmt.Errorf("error setting timeslice config for requests '%v' in claim '%v': %w", requests, claim.UID, err)
return nil, fmt.Errorf("error setting timeslice config for requests '%v' in claim '%v': %w", requests, claim.UID, err)
}
}
}
Expand All @@ -451,55 +450,51 @@ func (s *DeviceState) applySharingConfig(ctx context.Context, config configapi.S
if config.IsMps() {
mpsc, err := config.GetMpsConfig()
if err != nil {
return fmt.Errorf("error getting MPS configuration: %w", err)
return nil, fmt.Errorf("error getting MPS configuration: %w", err)
}
mpsControlDaemon := s.mpsManager.NewMpsControlDaemon(string(claim.UID), allocatableDevices)
if err := mpsControlDaemon.Start(ctx, mpsc); err != nil {
return fmt.Errorf("error starting MPS control daemon: %w", err)
return nil, fmt.Errorf("error starting MPS control daemon: %w", err)
}
if err := mpsControlDaemon.AssertReady(ctx); err != nil {
return fmt.Errorf("MPS control daemon is not yet ready: %w", err)
return nil, fmt.Errorf("MPS control daemon is not yet ready: %w", err)
}
configState.MpsControlDaemonID = mpsControlDaemon.GetID()
configState.containerEdits = mpsControlDaemon.GetCDIContainerEdits()
}

return nil
return configState, nil
}

func (s *DeviceState) applyImexChannelConfig(ctx context.Context, config *configapi.ImexChannelConfig, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) error {
func (s *DeviceState) applyImexChannelConfig(ctx context.Context, config *configapi.ImexChannelConfig, claim *resourceapi.ResourceClaim, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) (*DeviceConfigState, error) {
// Create any necessary IMEX channels and gather their CDI container edits.
for _, r := range results {
imexChannel := s.allocatable[r.Device].ImexChannel
if err := s.nvdevlib.createImexChannelDevice(imexChannel.Channel); err != nil {
return fmt.Errorf("error creating IMEX channel device: %w", err)
return nil, fmt.Errorf("error creating IMEX channel device: %w", err)
}
configState.containerEdits = configState.containerEdits.Append(s.cdi.GetImexChannelContainerEdits(imexChannel))
}

return nil
return configState, nil
}

func (s *DeviceState) applyGpuDriverConfig(ctx context.Context, config *configapi.GpuDriverConfig, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) error {
// Get the list of allocatable devices this config is being applied over.
allocatableDevices := make(AllocatableDevices)
for _, r := range results {
allocatableDevices[r.Device] = s.allocatable[r.Device]
func (s *DeviceState) applyGpuDriverConfig(ctx context.Context, config *configapi.GpuDriverConfig, results []*resourceapi.DeviceRequestAllocationResult, configState *DeviceConfigState) (*DeviceConfigState, error) {
if config.Driver != configapi.VfioPciDriver {
return configState, nil
}

if config.Driver == configapi.VfioPciDriver {
// Apply vfio-pci driver settings.
for _, r := range results {
info := allocatableDevices[r.Device]
err := s.vfioPciManager.Configure(info.Gpu)
if err != nil {
return err
}
configState.containerEdits = configState.containerEdits.Append(s.vfioPciManager.GetCDIContainerEdits(info.Gpu))
// Apply vfio-pci driver settings.
for _, r := range results {
info := s.allocatable[r.Device]
err := s.vfioPciManager.Configure(info.Gpu)
if err != nil {
return nil, err
}
configState.containerEdits = configState.containerEdits.Append(s.vfioPciManager.GetCDIContainerEdits(info.Gpu))
}

return nil
return configState, nil
}

// GetOpaqueDeviceConfigs returns an ordered list of the configs contained in possibleConfigs for this driver.
Expand Down
2 changes: 1 addition & 1 deletion cmd/nvidia-dra-plugin/deviceinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

type GpuInfo struct {
UUID string `json:"uuid"`
PciAddress string `json:"pciAddress"`
index int
minor int
migEnabled bool
Expand All @@ -40,7 +41,6 @@ type GpuInfo struct {
driverVersion string
cudaDriverVersion string
migProfiles []*MigProfileInfo
PciAddress string `json:"pciAddress"`
}

type MigDeviceInfo struct {
Expand Down
3 changes: 0 additions & 3 deletions cmd/nvidia-dra-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ type Flags struct {
hostDriverRoot string
nvidiaCTKPath string
deviceClasses sets.Set[string]
pciDevicesRoot string
sysModulesRoot string
vfioDevicesRoot string
}

type Config struct {
Expand Down
12 changes: 3 additions & 9 deletions cmd/nvidia-dra-plugin/nvlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,6 @@ func (l deviceLib) enumerateImexChannels(config *Config) (AllocatableDevices, er
return devices, nil
}

func getPciAddressFromNvmlPciInfo(info nvml.PciInfo) string {
return fmt.Sprintf("%04x:%02x:%02x.0", info.Domain, info.Bus, info.Device)
}

func (l deviceLib) getGpuInfo(index int, device nvdev.Device) (*GpuInfo, error) {
minor, ret := device.GetMinorNumber()
if ret != nvml.SUCCESS {
Expand Down Expand Up @@ -244,12 +240,10 @@ func (l deviceLib) getGpuInfo(index int, device nvdev.Device) (*GpuInfo, error)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting CUDA driver version: %w", err)
}
pciInfo, ret := l.nvmllib.DeviceGetPciInfo(device)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting PCI info for device %d: %w", index, err)
pciAddress, err := device.GetPCIBusID()
if err != nil {
return nil, err
}
pciAddress := getPciAddressFromNvmlPciInfo(pciInfo)

var migProfiles []*MigProfileInfo
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
giProfileInfo, ret := device.GetGpuInstanceProfileInfo(i)
Expand Down
18 changes: 18 additions & 0 deletions deployments/helm/k8s-dra-driver/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,21 @@ Filter a list by a set of valid values
{{- end }}
{{- $result -}}
{{- end -}}

{{- define "k8s-dra-driver.vfiopciDeviceClassVolumes" -}}
- name: sysfs
hostPath:
path: /sys
- name: dev-vfio
hostPath:
path: /dev/vfio
{{- end -}}

{{- define "k8s-dra-driver.vfiopciDeviceClassVolumeMounts" -}}
- name: sysfs
mountPath: /sys
readOnly: false
- name: dev-vfio
mountPath: /dev/vfio
readOnly: false
{{- end -}}
18 changes: 6 additions & 12 deletions deployments/helm/k8s-dra-driver/templates/kubeletplugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,9 @@ spec:
- name: driver-root
mountPath: /driver-root
readOnly: true
- name: sysfs
mountPath: /sys
readOnly: false
- name: dev-vfio
mountPath: /dev/vfio
readOnly: false
{{- if include "k8s-dra-driver.listHas" (list $.Values.deviceClasses "vfiopci") }}
{{- include "k8s-dra-driver.vfiopciDeviceClassVolumeMounts" . | nindent 8 }}
{{- end }}
volumes:
- name: plugins-registry
hostPath:
Expand All @@ -122,12 +119,9 @@ spec:
- name: driver-root
hostPath:
path: {{ .Values.nvidiaDriverRoot }}
- name: sysfs
hostPath:
path: /sys
- name: dev-vfio
hostPath:
path: /dev/vfio
{{- if include "k8s-dra-driver.listHas" (list $.Values.deviceClasses "vfiopci") }}
{{- include "k8s-dra-driver.vfiopciDeviceClassVolumes" . | nindent 6}}
{{- end }}
{{- with .Values.kubeletPlugin.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
Expand Down
2 changes: 1 addition & 1 deletion deployments/helm/k8s-dra-driver/templates/validation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

{{- $validDeviceClasses := list "gpu" "mig" "imex" }}
{{- $validDeviceClasses := list "gpu" "mig" "imex" "vfiopci" }}

{{- if not (kindIs "slice" .Values.deviceClasses) }}
{{- $error := "" }}
Expand Down

0 comments on commit f152735

Please sign in to comment.