Skip to content

Commit

Permalink
adding multiple other devices fake device plugins
Browse files: browse the repository at this point in the history
enoodle committed Nov 25, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 83b94a7 commit 22dc131
Showing 5 changed files with 67 additions and 22 deletions.
11 changes: 7 additions & 4 deletions cmd/device-plugin/main.go
Original file line number Diff line number Diff line change
@@ -42,10 +42,13 @@ func main() {
initNvidiaSmi()
initPreloaders()

devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient)
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
devicePlugins := deviceplugin.NewDevicePlugins(topology, kubeClient)
for _, devicePlugin := range devicePlugins {
log.Printf("Starting device plugin for %s\n", devicePlugin.Name())
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
}
}

sig := make(chan os.Signal, 1)
21 changes: 14 additions & 7 deletions internal/common/topology/types.go
Original file line number Diff line number Diff line change
@@ -14,16 +14,18 @@ type ClusterTopology struct {
}

type NodePoolTopology struct {
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"`
}

type NodeTopology struct {
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"`
}

type GpuDetails struct {
@@ -56,6 +58,11 @@ type Range struct {
Max int `yaml:"max"`
}

// GenericDevice describes a non-GPU device entry listed under a topology's
// "otherDevices" key.
type GenericDevice struct {
// Name is the resource name the device plugin registers for this device;
// a normalized form of it is also used to build the plugin's socket file name.
Name string `yaml:"name"`
// Count is the number of devices to create for this resource.
Count int `yaml:"count"`
}

// Errors
var ErrNoNodes = fmt.Errorf("no nodes found")
var ErrNoNode = fmt.Errorf("node not found")
39 changes: 32 additions & 7 deletions internal/deviceplugin/device_plugin.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,59 @@
package deviceplugin

import (
"path"
"strings"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
"github.com/spf13/viper"
"k8s.io/client-go/kubernetes"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

const (
resourceName = "nvidia.com/gpu"
nvidiaGPUResourceName = "nvidia.com/gpu"
)

// Interface is the contract implemented by every device plugin returned by
// NewDevicePlugins.
type Interface interface {
// Serve starts the plugin and reports any failure to do so.
Serve() error
// Name returns an identifier for the plugin; callers use it in log output.
Name() string
}

func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface {
func NewDevicePlugins(topology *topology.NodeTopology, kubeClient kubernetes.Interface) []Interface {
if topology == nil {
panic("topology is nil")
}

if viper.GetBool(constants.EnvFakeNode) {
return &FakeNodeDevicePlugin{
return []Interface{&FakeNodeDevicePlugin{
kubeClient: kubeClient,
gpuCount: getGpuCount(topology),
}
}}
}

devicePlugins := []Interface{
&RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
resourceName: nvidiaGPUResourceName,
},
}

return &RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
for _, genericDevice := range topology.OtherDevices {
devicePlugins = append(devicePlugins, &RealNodeDevicePlugin{
devs: createDevices(genericDevice.Count),
socket: path.Join(pluginapi.DevicePluginPath, normalizeDeviceName(genericDevice.Name)+".sock"),
resourceName: genericDevice.Name,
})
}

return devicePlugins
}

// normalizeDeviceName converts a resource name into a string usable as a
// socket file name by replacing every "/", "." and "-" with "_".
func normalizeDeviceName(deviceName string) string {
	separators := strings.NewReplacer(
		"/", "_",
		".", "_",
		"-", "_",
	)
	return separators.Replace(deviceName)
}
6 changes: 5 additions & 1 deletion internal/deviceplugin/fake_node.go
Original file line number Diff line number Diff line change
@@ -18,11 +18,15 @@ type FakeNodeDevicePlugin struct {
}

func (f *FakeNodeDevicePlugin) Serve() error {
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount)
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, nvidiaGPUResourceName, f.gpuCount, nvidiaGPUResourceName, f.gpuCount)
_, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status")
if err != nil {
return fmt.Errorf("failed to update node capacity and allocatable: %v", err)
}

return nil
}

// Name returns the fixed identifier of the fake-node device plugin.
func (f *FakeNodeDevicePlugin) Name() string {
	const pluginName = "FakeNodeDevicePlugin"
	return pluginName
}
12 changes: 9 additions & 3 deletions internal/deviceplugin/real_node.go
Original file line number Diff line number Diff line change
@@ -28,6 +28,8 @@ type RealNodeDevicePlugin struct {
stop chan interface{}
health chan *pluginapi.Device
server *grpc.Server

resourceName string
}

func getGpuCount(nodeTopology *topology.NodeTopology) int {
@@ -115,7 +117,7 @@ func (m *RealNodeDevicePlugin) Stop() error {
return m.cleanup()
}

func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error {
func (m *RealNodeDevicePlugin) Register(kubeletEndpoint string) error {
conn, err := dial(kubeletEndpoint, 5*time.Second)
if err != nil {
return err
@@ -126,7 +128,7 @@ func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) er
reqt := &pluginapi.RegisterRequest{
Version: pluginapi.Version,
Endpoint: path.Base(m.socket),
ResourceName: resourceName,
ResourceName: m.resourceName,
}

_, err = client.Register(context.Background(), reqt)
@@ -202,7 +204,7 @@ func (m *RealNodeDevicePlugin) Serve() error {
}
log.Println("Starting to serve on", m.socket)

err = m.Register(pluginapi.KubeletSocket, resourceName)
err = m.Register(pluginapi.KubeletSocket)
if err != nil {
log.Printf("Could not register device plugin: %s", err)
stopErr := m.Stop()
@@ -215,3 +217,7 @@ func (m *RealNodeDevicePlugin) Serve() error {

return nil
}

// Name returns an identifier for this plugin instance: a fixed prefix
// followed by the resource name the plugin was configured with.
func (m *RealNodeDevicePlugin) Name() string {
	const prefix = "RealNodeDevicePlugin-"
	return prefix + m.resourceName
}

0 comments on commit 22dc131

Please sign in to comment.