diff --git a/go.mod b/go.mod index 93a75709..d52924ab 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 toolchain go1.22.5 require ( - github.com/NVIDIA/go-nvlib v0.3.0 + github.com/NVIDIA/go-nvlib v0.6.0 github.com/NVIDIA/go-nvml v0.12.4-0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 53c58ae2..ba4a728a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/NVIDIA/go-nvlib v0.3.0 h1:vd7jSOthJTqzqIWZrv317xDr1+Mnjoy5X4N69W9YwQM= -github.com/NVIDIA/go-nvlib v0.3.0/go.mod h1:NasUuId9hYFvwzuOHCu9F2X6oTU2tG0JHTfbJYuDAbA= +github.com/NVIDIA/go-nvlib v0.6.0 h1:zAMBzCYT9xeyRQo0tb7HJbStkzajD6e5joyaQqJ2OGU= +github.com/NVIDIA/go-nvlib v0.6.0/go.mod h1:9UrsLGx/q1OrENygXjOuM5Ey5KCtiZhbvBlbUIxtGWY= github.com/NVIDIA/go-nvml v0.12.4-0 h1:4tkbB3pT1O77JGr0gQ6uD8FrsUPqP1A/EOEm2wI1TUg= github.com/NVIDIA/go-nvml v0.12.4-0/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= diff --git a/pkg/mig/mode/pci_test.go b/pkg/mig/mode/pci_test.go index fe580497..7c06147f 100644 --- a/pkg/mig/mode/pci_test.go +++ b/pkg/mig/mode/pci_test.go @@ -36,7 +36,7 @@ func NewMockPciA100Device() (*mockPciMigModeManager, error) { return nil, fmt.Errorf("error creating Mock A100 PCI device: %v", err) } - err = nvpci.AddMockA100("0000:80:05.1", 0) + err = nvpci.AddMockA100("0000:80:05.1", 0, nil) if err != nil { return nil, fmt.Errorf("error adding Mock A100 device to MockNvpci: %v", err) } diff --git a/pkg/types/nvdev.go b/pkg/types/nvdev.go index 1506aaae..b785adf2 100644 --- a/pkg/types/nvdev.go +++ b/pkg/types/nvdev.go @@ -41,7 +41,7 @@ func nvdevNewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemoryS nvmllib = nvml.New() } if nvdevlib == nil { - nvdevlib = nvdev.New(nvdev.WithNvml(nvmllib)) + nvdevlib = nvdev.New(nvmllib) } mp, err := nvdevlib.NewMigProfile(giProfileID, ciProfileID, ciEngProfileID, migMemorySizeMB, deviceMemorySizeBytes) @@ -57,7 +57,7 @@ func nvdevAssertValidMigProfileFormat(profile string) error { nvmllib = nvml.New() } if nvdevlib == nil { - nvdevlib = nvdev.New(nvdev.WithNvml(nvmllib)) + nvdevlib = nvdev.New(nvmllib) } return nvdevlib.AssertValidMigProfileFormat(profile) @@ -68,7 +68,7 @@ func nvdevParseMigProfile(profile string) (nvdev.MigProfile, error) { nvmllib = nvml.New() } if nvdevlib == nil { - nvdevlib = nvdev.New(nvdev.WithNvml(nvmllib)) + nvdevlib = nvdev.New(nvmllib) } ret := nvmllib.Init() @@ -141,8 +141,7 @@ func SetMockNVdevlib() { }, } - nvdevlib = nvdev.New( - nvdev.WithNvml(nvmllib), + nvdevlib = nvdev.New(nvmllib, nvdev.WithVerifySymbols(false), ) } diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go index 11aa139d..c2a6517d 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go @@ -38,7 +38,7 @@ type Interface interface { } type devicelib struct { - nvml nvml.Interface + nvmllib nvml.Interface skippedDevices map[string]struct{} verifySymbols *bool migProfiles []MigProfile @@ -47,14 +47,13 @@ type devicelib struct { var _ Interface = &devicelib{} // New creates a new instance of the 'device' interface. -func New(opts ...Option) Interface { - d := &devicelib{} +func New(nvmllib nvml.Interface, opts ...Option) Interface { + d := &devicelib{ + nvmllib: nvmllib, + } for _, opt := range opts { opt(d) } - if d.nvml == nil { - d.nvml = nvml.New() - } if d.verifySymbols == nil { verify := true d.verifySymbols = &verify @@ -68,13 +67,6 @@ func New(opts ...Option) Interface { return d } -// WithNvml provides an Option to set the NVML library used by the 'device' interface. -func WithNvml(nvml nvml.Interface) Option { - return func(d *devicelib) { - d.nvml = nvml - } -} - // WithVerifySymbols provides an option to toggle whether to verify select symbols exist in dynamic libraries before calling them. func WithVerifySymbols(verify bool) Option { return func(d *devicelib) { diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go index 10514591..5b21fc13 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go @@ -18,6 +18,7 @@ package device import ( "fmt" + "strings" "github.com/NVIDIA/go-nvml/pkg/nvml" ) @@ -30,6 +31,7 @@ type Device interface { GetCudaComputeCapabilityAsString() (string, error) GetMigDevices() ([]MigDevice, error) GetMigProfiles() ([]MigProfile, error) + GetPCIBusID() (string, error) IsMigCapable() (bool, error) IsMigEnabled() (bool, error) VisitMigDevices(func(j int, m MigDevice) error) error @@ -51,7 +53,7 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { // NewDeviceByUUID builds a new Device from a UUID. func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) { - dev, ret := d.nvml.DeviceGetHandleByUUID(uuid) + dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid) if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret) } @@ -140,6 +142,29 @@ func (d *device) GetBrandAsString() (string, error) { return "", fmt.Errorf("error interpreting device brand as string: %v", brand) } +// GetPCIBusID returns the string representation of the bus ID. +func (d *device) GetPCIBusID() (string, error) { + info, ret := d.GetPciInfo() + if ret != nvml.SUCCESS { + return "", fmt.Errorf("error getting PCI info: %w", ret) + } + + var bytes []byte + for _, b := range info.BusId { + if byte(b) == '\x00' { + break + } + bytes = append(bytes, byte(b)) + } + id := strings.ToLower(string(bytes)) + + if id != "0000" { + id = strings.TrimPrefix(id, "0000") + } + + return id, nil +} + // GetCudaComputeCapabilityAsString returns the Device's CUDA compute capability as a version string. func (d *device) GetCudaComputeCapabilityAsString() (string, error) { major, minor, ret := d.GetCudaComputeCapability() @@ -334,13 +359,13 @@ func (d *device) isSkipped() (bool, error) { // VisitDevices visits each top-level device and invokes a callback function for it. func (d *devicelib) VisitDevices(visit func(int, Device) error) error { - count, ret := d.nvml.DeviceGetCount() + count, ret := d.nvmllib.DeviceGetCount() if ret != nvml.SUCCESS { return fmt.Errorf("error getting device count: %v", ret) } for i := 0; i < count; i++ { - device, ret := d.nvml.DeviceGetHandleByIndex(i) + device, ret := d.nvmllib.DeviceGetHandleByIndex(i) if ret != nvml.SUCCESS { return fmt.Errorf("error getting device handle for index '%v': %v", i, ret) } @@ -469,5 +494,5 @@ func (d *devicelib) hasSymbol(symbol string) bool { return true } - return d.nvml.Extensions().LookupSymbol(symbol) == nil + return d.nvmllib.Extensions().LookupSymbol(symbol) == nil } diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go index b02d4176..7145a06b 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go @@ -50,7 +50,7 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) { // NewMigDeviceByUUID builds a new MigDevice from a UUID. func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) { - dev, ret := d.nvml.DeviceGetHandleByUUID(uuid) + dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid) if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret) } diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mock.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mock.go index 7c1b69dd..9b3d6e2a 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mock.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mock.go @@ -20,6 +20,8 @@ import ( "fmt" "os" "path/filepath" + "regexp" + "strconv" "github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes" ) @@ -55,64 +57,114 @@ func (m *MockNvpci) Cleanup() { os.RemoveAll(m.pciDevicesRoot) } +func validatePCIAddress(addr string) error { + r := regexp.MustCompile(`0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]`) + if !r.Match([]byte(addr)) { + return fmt.Errorf(`invalid PCI address should match 0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]: %s`, addr) + } + + return nil +} + // AddMockA100 Create an A100 like GPU mock device. -func (m *MockNvpci) AddMockA100(address string, numaNode int) error { +func (m *MockNvpci) AddMockA100(address string, numaNode int, sriov *SriovInfo) error { + err := validatePCIAddress(address) + if err != nil { + return err + } + deviceDir := filepath.Join(m.pciDevicesRoot, address) - err := os.MkdirAll(deviceDir, 0755) + err = os.MkdirAll(deviceDir, 0755) if err != nil { return err } - vendor, err := os.Create(filepath.Join(deviceDir, "vendor")) + err = createNVIDIAgpuFiles(deviceDir) if err != nil { return err } - _, err = vendor.WriteString(fmt.Sprintf("0x%x", PCINvidiaVendorID)) + + iommuGroup := 20 + _, err = os.Create(filepath.Join(deviceDir, strconv.Itoa(iommuGroup))) + if err != nil { + return err + } + err = os.Symlink(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)), filepath.Join(deviceDir, "iommu_group")) if err != nil { return err } - class, err := os.Create(filepath.Join(deviceDir, "class")) + numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) if err != nil { return err } - _, err = class.WriteString(fmt.Sprintf("0x%x", PCI3dControllerClass)) + _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) if err != nil { return err } - device, err := os.Create(filepath.Join(deviceDir, "device")) + if sriov != nil && sriov.PhysicalFunction != nil { + totalVFs, err := os.Create(filepath.Join(deviceDir, "sriov_totalvfs")) + if err != nil { + return err + } + _, err = fmt.Fprintf(totalVFs, "%d", sriov.PhysicalFunction.TotalVFs) + if err != nil { + return err + } + + numVFs, err := os.Create(filepath.Join(deviceDir, "sriov_numvfs")) + if err != nil { + return err + } + _, err = fmt.Fprintf(numVFs, "%d", sriov.PhysicalFunction.NumVFs) + if err != nil { + return err + } + for i := 1; i <= int(sriov.PhysicalFunction.NumVFs); i++ { + err = m.createVf(address, i, iommuGroup, numaNode) + if err != nil { + return err + } + } + } + + return nil +} + +func createNVIDIAgpuFiles(deviceDir string) error { + vendor, err := os.Create(filepath.Join(deviceDir, "vendor")) if err != nil { return err } - _, err = device.WriteString("0x20bf") + _, err = vendor.WriteString(fmt.Sprintf("0x%x", PCINvidiaVendorID)) if err != nil { return err } - _, err = os.Create(filepath.Join(deviceDir, "nvidia")) + class, err := os.Create(filepath.Join(deviceDir, "class")) if err != nil { return err } - err = os.Symlink(filepath.Join(deviceDir, "nvidia"), filepath.Join(deviceDir, "driver")) + _, err = class.WriteString(fmt.Sprintf("0x%x", PCI3dControllerClass)) if err != nil { return err } - _, err = os.Create(filepath.Join(deviceDir, "20")) + device, err := os.Create(filepath.Join(deviceDir, "device")) if err != nil { return err } - err = os.Symlink(filepath.Join(deviceDir, "20"), filepath.Join(deviceDir, "iommu_group")) + _, err = device.WriteString("0x20bf") if err != nil { return err } - numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) + _, err = os.Create(filepath.Join(deviceDir, "nvidia")) if err != nil { return err } - _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) + err = os.Symlink(filepath.Join(deviceDir, "nvidia"), filepath.Join(deviceDir, "driver")) if err != nil { return err } @@ -156,3 +208,53 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error { return nil } + +func (m *MockNvpci) createVf(pfAddress string, id, iommu_group, numaNode int) error { + functionID := pfAddress[len(pfAddress)-1] + // we are verifying the last character of pfAddress is integer. + functionNumber, err := strconv.Atoi(string(functionID)) + if err != nil { + return fmt.Errorf("can't conver physical function pci address function number %s to integer: %v", string(functionID), err) + } + + vfFunctionNumber := functionNumber + id + vfAddress := pfAddress[:len(pfAddress)-1] + strconv.Itoa(vfFunctionNumber) + + deviceDir := filepath.Join(m.pciDevicesRoot, vfAddress) + err = os.MkdirAll(deviceDir, 0755) + if err != nil { + return err + } + + err = createNVIDIAgpuFiles(deviceDir) + if err != nil { + return err + } + + vfIommuGroup := strconv.Itoa(iommu_group + id) + + _, err = os.Create(filepath.Join(deviceDir, vfIommuGroup)) + if err != nil { + return err + } + err = os.Symlink(filepath.Join(deviceDir, vfIommuGroup), filepath.Join(deviceDir, "iommu_group")) + if err != nil { + return err + } + + numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) + if err != nil { + return err + } + _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) + if err != nil { + return err + } + + err = os.Symlink(filepath.Join(m.pciDevicesRoot, pfAddress), filepath.Join(deviceDir, "physfn")) + if err != nil { + return err + } + + return nil +} diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go index 6d83a577..6ff197b1 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go @@ -76,6 +76,32 @@ type nvpci struct { var _ Interface = (*nvpci)(nil) var _ ResourceInterface = (*MemoryResources)(nil) +// SriovInfo indicates whether device is VF/PF for SRIOV capable devices. +// Only one should be set at any given time. +type SriovInfo struct { + PhysicalFunction *SriovPhysicalFunction + VirtualFunction *SriovVirtualFunction +} + +// SriovPhysicalFunction stores info about SRIOV physical function. +type SriovPhysicalFunction struct { + TotalVFs uint64 + NumVFs uint64 +} + +// SriovVirtualFunction keeps data about SRIOV virtual function. +type SriovVirtualFunction struct { + PhysicalFunction *NvidiaPCIDevice +} + +func (s *SriovInfo) IsPF() bool { + return s != nil && s.PhysicalFunction != nil +} + +func (s *SriovInfo) IsVF() bool { + return s != nil && s.VirtualFunction != nil +} + // NvidiaPCIDevice represents a PCI device for an NVIDIA product. type NvidiaPCIDevice struct { Path string @@ -90,7 +116,7 @@ type NvidiaPCIDevice struct { NumaNode int Config *ConfigSpace Resources MemoryResources - IsVF bool + SriovInfo SriovInfo } // IsVGAController if class == 0x300. @@ -178,9 +204,11 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) { } var nvdevices []*NvidiaPCIDevice + // Cache devices for each GetAllDevices invocation to speed things up. + cache := make(map[string]*NvidiaPCIDevice) for _, deviceDir := range deviceDirs { deviceAddress := deviceDir.Name() - nvdevice, err := p.GetGPUByPciBusID(deviceAddress) + nvdevice, err := p.getGPUByPciBusID(deviceAddress, cache) if err != nil { return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err) } @@ -206,6 +234,16 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) { // GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID). func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { + // Pass nil as to force reading device information from sysfs. + return p.getGPUByPciBusID(address, nil) +} + +func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevice) (*NvidiaPCIDevice, error) { + if cache != nil { + if pciDevice, exists := cache[address]; exists { + return pciDevice, nil + } + } devicePath := filepath.Join(p.pciDevicesRoot, address) vendor, err := os.ReadFile(path.Join(devicePath, "vendor")) @@ -265,16 +303,6 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err) } - // device is a virtual function (VF) if "physfn" symlink exists. - var isVF bool - _, err = filepath.EvalSymlinks(path.Join(devicePath, "physfn")) - if err == nil { - isVF = true - } - if err != nil && !os.IsNotExist(err) { - return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(devicePath, "physfn"), err) - } - numa, err := os.ReadFile(path.Join(devicePath, "numa_node")) if err != nil { return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err) @@ -328,6 +356,28 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { className = UnknownClassString } + var sriovInfo SriovInfo + // Device is a virtual function (VF) if "physfn" symlink exists. + physFnAddress, err := filepath.EvalSymlinks(path.Join(devicePath, "physfn")) + if err == nil { + physFn, err := p.getGPUByPciBusID(filepath.Base(physFnAddress), cache) + if err != nil { + return nil, fmt.Errorf("unable to detect physfn for %s: %v", address, err) + } + sriovInfo = SriovInfo{ + VirtualFunction: &SriovVirtualFunction{ + PhysicalFunction: physFn, + }, + } + } else if os.IsNotExist(err) { + sriovInfo, err = p.getSriovInfoForPhysicalFunction(devicePath) + if err != nil { + return nil, fmt.Errorf("unable to read SRIOV physical function details for %s: %v", devicePath, err) + } + } else { + return nil, fmt.Errorf("unable to read %s: %v", path.Join(devicePath, "physfn"), err) + } + nvdevice := &NvidiaPCIDevice{ Path: devicePath, Address: address, @@ -339,9 +389,14 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { NumaNode: int(numaNode), Config: config, Resources: resources, - IsVF: isVF, DeviceName: deviceName, ClassName: className, + SriovInfo: sriovInfo, + } + + // Cache physical functions only as VF can't be a root device. + if cache != nil && sriovInfo.IsPF() { + cache[address] = nvdevice } return nvdevice, nil @@ -407,7 +462,7 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) { var filtered []*NvidiaPCIDevice for _, d := range devices { - if d.IsGPU() && !d.IsVF { + if d.IsGPU() && !d.SriovInfo.IsVF() { filtered = append(filtered, d) } } @@ -428,3 +483,41 @@ func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) { return gpus[i], nil } + +func (p *nvpci) getSriovInfoForPhysicalFunction(devicePath string) (sriovInfo SriovInfo, err error) { + totalVfsPath := filepath.Join(devicePath, "sriov_totalvfs") + numVfsPath := filepath.Join(devicePath, "sriov_numvfs") + + // No file for sriov_totalvfs exists? Not an SRIOV device, return nil + _, err = os.Stat(totalVfsPath) + if err != nil && os.IsNotExist(err) { + return sriovInfo, nil + } + sriovTotalVfs, err := os.ReadFile(totalVfsPath) + if err != nil { + return sriovInfo, fmt.Errorf("unable to read sriov_totalvfs: %v", err) + } + totalVfsStr := strings.TrimSpace(string(sriovTotalVfs)) + totalVfsInt, err := strconv.ParseUint(totalVfsStr, 10, 16) + if err != nil { + return sriovInfo, fmt.Errorf("unable to convert sriov_totalvfs to uint64: %v", err) + } + + sriovNumVfs, err := os.ReadFile(numVfsPath) + if err != nil { + return sriovInfo, fmt.Errorf("unable to read sriov_numvfs for: %v", err) + } + numVfsStr := strings.TrimSpace(string(sriovNumVfs)) + numVfsInt, err := strconv.ParseUint(numVfsStr, 10, 16) + if err != nil { + return sriovInfo, fmt.Errorf("unable to convert sriov_numvfs to uint64: %v", err) + } + + sriovInfo = SriovInfo{ + PhysicalFunction: &SriovPhysicalFunction{ + TotalVFs: totalVfsInt, + NumVFs: numVfsInt, + }, + } + return sriovInfo, nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt index a49f70e1..241ab98f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# github.com/NVIDIA/go-nvlib v0.3.0 +# github.com/NVIDIA/go-nvlib v0.6.0 ## explicit; go 1.20 github.com/NVIDIA/go-nvlib/pkg/nvlib/device github.com/NVIDIA/go-nvlib/pkg/nvpci