diff --git a/sriovnet.go b/sriovnet.go index 5e9ff38..2504a6c 100644 --- a/sriovnet.go +++ b/sriovnet.go @@ -520,3 +520,28 @@ func GetPciFromNetDevice(name string) (string, error) { } return base, nil } + +// GetPKeyByIndexFromPci returns the PKey stored under given index for the IB PCI device +func GetPKeyByIndexFromPci(pciAddress string, index int) (string, error) { + pciDir := filepath.Join(PciSysDir, pciAddress, "infiniband") + dirEntries, err := utilfs.Fs.ReadDir(pciDir) + if err != nil { + return "", fmt.Errorf("failed to read infiniband directory: %v", err) + } + if len(dirEntries) == 0 { + return "", fmt.Errorf("infiniband directory is empty for device: %s", pciAddress) + } + + indexFilePath := filepath.Join(pciDir, dirEntries[0].Name(), "ports", "1", "pkeys", strconv.Itoa(index)) + pKeyBytes, err := utilfs.Fs.ReadFile(indexFilePath) + if err != nil { + return "", fmt.Errorf("failed to read PKey file: %v", err) + } + + return strings.TrimSpace(string(pKeyBytes)), nil +} + +// GetDefaultPKeyFromPci returns the index0 PKey for the IB PCI device +func GetDefaultPKeyFromPci(pciAddress string) (string, error) { + return GetPKeyByIndexFromPci(pciAddress, 0) +} diff --git a/sriovnet_test.go b/sriovnet_test.go index 8ed9c4e..8232a9a 100644 --- a/sriovnet_test.go +++ b/sriovnet_test.go @@ -19,6 +19,7 @@ package sriovnet import ( "os" "path/filepath" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -246,3 +247,62 @@ func TestGetPciFromNetDeviceNotPCI(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "is not a PCI device") } + +func TestGetPKeyByIndexFromPci(t *testing.T) { + teardown := setupFakeFs(t) + defer teardown() + + pciAddress := "0000:03:00.2" + pKeysFolder := "/sys/bus/pci/devices/0000:03:00.2/infiniband/mlx5_2/ports/1/pkeys/" + pKeysToIndex := map[string]int{ + "0x55": 2, + "0x8066": 5, + } + + err := utilfs.Fs.MkdirAll(pKeysFolder, os.FileMode(0755)) + assert.NoError(t, err) + for pKey, index := range pKeysToIndex { + file, err := utilfs.Fs.Create(filepath.Join(pKeysFolder, strconv.Itoa(index))) + assert.NoError(t, err) + _, err = file.Write([]byte(pKey)) + assert.NoError(t, err) + err = file.Close() + assert.NoError(t, err) + } + + for expectedPKey, index := range pKeysToIndex { + foundPKey, err := GetPKeyByIndexFromPci(pciAddress, index) + assert.NoError(t, err) + assert.Equal(t, expectedPKey, foundPKey) + } +} + +func TestGetDefaultPKeyFromPci(t *testing.T) { + teardown := setupFakeFs(t) + defer teardown() + + devices := map[string]struct { + path string + pkey string + }{ + "0000:03:00.2": {"/sys/bus/pci/devices/0000:03:00.2/infiniband/mlx5_2/ports/1/pkeys/", "0x66"}, + "0000:03:00.3": {"/sys/bus/pci/devices/0000:03:00.3/infiniband/mlx5_3/ports/1/pkeys/", "0x424"}, + } + + for _, v := range devices { + err := utilfs.Fs.MkdirAll(v.path, os.FileMode(0755)) + assert.NoError(t, err) + file, err := utilfs.Fs.Create(filepath.Join(v.path, "0")) + assert.NoError(t, err) + _, err = file.Write([]byte(v.pkey)) + assert.NoError(t, err) + err = file.Close() + assert.NoError(t, err) + } + + for k, v := range devices { + pKey, err := GetDefaultPKeyFromPci(k) + assert.NoError(t, err) + assert.Equal(t, v.pkey, pKey) + } +}