diff --git a/cmd/kvdpa-cli/kvdpa-cli.go b/cmd/kvdpa-cli/kvdpa-cli.go index 12729a1..4ca3696 100644 --- a/cmd/kvdpa-cli/kvdpa-cli.go +++ b/cmd/kvdpa-cli/kvdpa-cli.go @@ -3,7 +3,6 @@ package main import ( "fmt" "os" - "strings" "text/template" vdpa "github.com/k8snetworkplumbingwg/govdpa/pkg/kvdpa" @@ -26,26 +25,19 @@ const deviceTemplate = ` - Name: {{ .Name }} func listAction(c *cli.Context) error { var devs []vdpa.VdpaDevice var err error - if c.String("mgmtdev") != "" { - var bus, name string - nameParts := strings.Split(c.String("mgmtdev"), "/") - if len(nameParts) == 1 { - name = nameParts[0] - } else if len(nameParts) == 2 { - bus = nameParts[0] - name = nameParts[1] - } else { - return fmt.Errorf("Invalid management device name %s", c.String("mgmtdev")) - } - devs, err = vdpa.GetVdpaDevicesByMgmtDev(bus, name) + var mgmtDev = c.String("mgmtdev") + if mgmtDev != "" { + var busName, devName string + busName, devName, err = vdpa.ExtractBusAndMgmtDevice(mgmtDev) if err != nil { return err } + devs, err = vdpa.GetVdpaDevicesByMgmtDev(busName, devName) } else { devs, err = vdpa.ListVdpaDevices() - if err != nil { - fmt.Println(err) - } + } + if err != nil { + return err } tmpl := template.Must(template.New("device").Parse(deviceTemplate)) diff --git a/pkg/kvdpa/device.go b/pkg/kvdpa/device.go index 120b495..473b2b7 100644 --- a/pkg/kvdpa/device.go +++ b/pkg/kvdpa/device.go @@ -1,10 +1,8 @@ package kvdpa import ( - "errors" "os" "path/filepath" - "strings" "syscall" "github.com/vishvananda/netlink/nl" @@ -185,7 +183,8 @@ func GetVdpaDevice(name string) (VdpaDevice, error) { return nil, err } - vdpaDevs, err := parseDevLinkVdpaDevList(msgs) + // No filters, expecting to parse attributes for the device with the given name + vdpaDevs, err := parseDevLinkVdpaDevList("", "", msgs) if err != nil { return nil, err } @@ -197,50 +196,27 @@ GetVdpaDevicesByMgmtDev returns the VdpaDevice objects whose MgmtDev has the given bus and device names. */ func GetVdpaDevicesByMgmtDev(busName, devName string) ([]VdpaDevice, error) { - result := []VdpaDevice{} - devices, err := ListVdpaDevices() - if err != nil { - return nil, err - } - for _, device := range devices { - if device.MgmtDev() != nil && - device.MgmtDev().BusName() == busName && - device.MgmtDev().DevName() == devName { - result = append(result, device) - } - } - if len(result) == 0 { - return nil, syscall.ENODEV - } - return result, nil + return listVdpaDevicesWithBusDevName(busName, devName) } /*ListVdpaDevices returns a list of all available vdpa devices */ func ListVdpaDevices() ([]VdpaDevice, error) { + return listVdpaDevicesWithBusDevName("", "") +} + +func listVdpaDevicesWithBusDevName(busName, devName string) ([]VdpaDevice, error) { msgs, err := GetNetlinkOps().RunVdpaNetlinkCmd(VdpaCmdDevGet, syscall.NLM_F_DUMP, nil) if err != nil { return nil, err } - vdpaDevs, err := parseDevLinkVdpaDevList(msgs) + vdpaDevs, err := parseDevLinkVdpaDevList(busName, devName, msgs) if err != nil { return nil, err } return vdpaDevs, nil } -func extractBusNameAndMgmtDeviceName(fullMgmtDeviceName string) (busName string, mgmtDeviceName string, err error) { - numSlashes := strings.Count(fullMgmtDeviceName, "/") - if numSlashes > 1 { - return "", "", errors.New("expected mgmtDeviceName to be either in the format / or ") - } else if numSlashes == 0 { - return "", fullMgmtDeviceName, nil - } else { - values := strings.Split(fullMgmtDeviceName, "/") - return values[0], values[1], nil - } -} - /* GetVdpaDevicesByPciAddress returns the VdpaDevice objects for the given pciAddress @@ -249,7 +225,7 @@ GetVdpaDevicesByPciAddress returns the VdpaDevice objects for the given pciAddre - MgmtDevName */ func GetVdpaDevicesByPciAddress(pciAddress string) ([]VdpaDevice, error) { - busName, mgmtDeviceName, err := extractBusNameAndMgmtDeviceName(pciAddress) + busName, mgmtDeviceName, err := ExtractBusAndMgmtDevice(pciAddress) if err != nil { return nil, unix.EINVAL } @@ -263,7 +239,7 @@ func AddVdpaDevice(mgmtDeviceName string, vdpaDeviceName string) error { return unix.EINVAL } - busName, mgmtDeviceName, err := extractBusNameAndMgmtDeviceName(mgmtDeviceName) + busName, mgmtDeviceName, err := ExtractBusAndMgmtDevice(mgmtDeviceName) if err != nil { return unix.EINVAL } @@ -317,7 +293,7 @@ func DeleteVdpaDevice(vdpaDeviceName string) error { return nil } -func parseDevLinkVdpaDevList(msgs [][]byte) ([]VdpaDevice, error) { +func parseDevLinkVdpaDevList(busName string, mgmtDeviceName string, msgs [][]byte) ([]VdpaDevice, error) { devices := make([]VdpaDevice, 0, len(msgs)) for _, m := range msgs { @@ -329,6 +305,15 @@ func parseDevLinkVdpaDevList(msgs [][]byte) ([]VdpaDevice, error) { if err = dev.parseAttributes(attrs); err != nil { return nil, err } + + if busName != "" && busName != dev.mgmtDev.busName { + continue + } + + if mgmtDeviceName != "" && mgmtDeviceName != dev.mgmtDev.devName { + continue + } + if err = dev.getBusInfo(); err != nil { return nil, err } diff --git a/pkg/kvdpa/device_test.go b/pkg/kvdpa/device_test.go index aaa0668..b584105 100644 --- a/pkg/kvdpa/device_test.go +++ b/pkg/kvdpa/device_test.go @@ -128,6 +128,98 @@ func TestVdpaDevList(t *testing.T) { } } +func TestVdpaDevListWithFilter(t *testing.T) { + tests := []struct { + name string + err bool + response []VdpaDevice + }{ + { + name: "Multiple SR-IOV and SF devices", + err: false, + response: []VdpaDevice{ + &vdpaDev{ + name: "vdpa0", + mgmtDev: &mgmtDev{ + devName: "0000:01:01", + }, + }, + &vdpaDev{ + name: "vdpa1", + mgmtDev: &mgmtDev{ + busName: "pci", + devName: "0000:01:02", + }, + }, + &vdpaDev{ + name: "vdpa2", + mgmtDev: &mgmtDev{ + busName: "pci", + devName: "0000:01:02", + }, + }, + &vdpaDev{ + name: "vdpa3", + mgmtDev: &mgmtDev{ + busName: "pci", + devName: "0000:01:03", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%s", "TestVdpaDevListWithFilter", tt.name), func(t *testing.T) { + netLinkMock := &mocks.NetlinkOps{} + SetNetlinkOps(netLinkMock) + netLinkMock.On("RunVdpaNetlinkCmd", + VdpaCmdDevGet, + mock.MatchedBy(func(flags int) bool { + return (flags|syscall.NLM_F_DUMP != 0) + }), + mock.AnythingOfType("[]*nl.RtAttr")). + Return(vdpaDevToNlMessage(t, tt.response...), nil) + // no filters, all devices are returned + devs, err := ListVdpaDevices() + if tt.err { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, tt.response, devs) + } + // mgmtdev: 0000:01:01 + devs, err = GetVdpaDevicesByPciAddress("0000:01:01") + if tt.err { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, len(devs), 1) + assert.Equal(t, tt.response[0], devs[0]) + } + // mgmtdev: pci/0000:01:02 + devs, err = GetVdpaDevicesByPciAddress("pci/0000:01:02") + if tt.err { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, len(devs), 2) + assert.Equal(t, tt.response[1], devs[0]) + assert.Equal(t, tt.response[2], devs[1]) + } + // mgmtdev: pci/0000:01:03 + devs, err = GetVdpaDevicesByPciAddress("pci/0000:01:03") + if tt.err { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, len(devs), 1) + assert.Equal(t, tt.response[3], devs[0]) + } + }) + } +} + func TestVdpaDevGet(t *testing.T) { tests := []struct { name string @@ -304,7 +396,7 @@ func TestVdpaDevGetByMgmt(t *testing.T) { }, { name: "Wrong", - err: syscall.ENODEV, + response: []VdpaDevice{}, mgmtDevName: "noDev", mgmtBusName: "wrongBus", }, diff --git a/pkg/kvdpa/util.go b/pkg/kvdpa/util.go new file mode 100644 index 0000000..8cc7173 --- /dev/null +++ b/pkg/kvdpa/util.go @@ -0,0 +1,22 @@ +package kvdpa + +import ( + "errors" + "strings" +) + +// ExtractBusAndMgmtDevice extracts the busName and deviceName from a full device address (e.g. pci) +// example 1: pci/65:0000.1 -> "pci", "65:0000.1", nil +// example 2: vdpa_sim -> "", "vdpa_sim", nil +// example 3: pci/65:0000.1/1 -> "", "", err +func ExtractBusAndMgmtDevice(fullMgmtDeviceName string) (busName string, mgmtDeviceName string, err error) { + numSlashes := strings.Count(fullMgmtDeviceName, "/") + if numSlashes > 1 { + return "", "", errors.New("expected mgmtDeviceName to be either in the format / or ") + } else if numSlashes == 0 { + return "", fullMgmtDeviceName, nil + } else { + values := strings.Split(fullMgmtDeviceName, "/") + return values[0], values[1], nil + } +} diff --git a/pkg/kvdpa/util_test.go b/pkg/kvdpa/util_test.go new file mode 100644 index 0000000..c62c964 --- /dev/null +++ b/pkg/kvdpa/util_test.go @@ -0,0 +1,53 @@ +package kvdpa + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractBusAndMgmtDevice(t *testing.T) { + tests := []struct { + testName string + deviceAddress string + busName string + devName string + err bool + }{ + { + testName: "regular PCI address", + deviceAddress: "pci/0000:65:00.1", + busName: "pci", + devName: "0000:65:00.1", + err: false, + }, + { + testName: "no bus", + deviceAddress: "vdpa_sim", + busName: "", + devName: "vdpa_sim", + err: false, + }, + { + testName: "wrong address", + deviceAddress: "pci/0000:65:00.1/0", + busName: "", + devName: "", + err: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%s", "TestExtractBusAndMgmtDevice", tt.testName), func(t *testing.T) { + busName, devName, err := ExtractBusAndMgmtDevice(tt.deviceAddress) + if tt.err { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, tt.busName, busName) + assert.Equal(t, tt.devName, devName) + } + }) + } +}