Skip to content

Commit

Permalink
Merge pull request #20 from lmilleri/vdpa-device-by-pci
Browse files Browse the repository at this point in the history
Internal refactoring to avoid unneeded netlink parsing
  • Loading branch information
amorenoz authored Sep 26, 2023
2 parents ff4e4ec + 3335d73 commit 07c1031
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 52 deletions.
24 changes: 8 additions & 16 deletions cmd/kvdpa-cli/kvdpa-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"fmt"
"os"
"strings"
"text/template"

vdpa "github.com/k8snetworkplumbingwg/govdpa/pkg/kvdpa"
Expand All @@ -26,26 +25,19 @@ const deviceTemplate = ` - Name: {{ .Name }}
func listAction(c *cli.Context) error {
var devs []vdpa.VdpaDevice
var err error
if c.String("mgmtdev") != "" {
var bus, name string
nameParts := strings.Split(c.String("mgmtdev"), "/")
if len(nameParts) == 1 {
name = nameParts[0]
} else if len(nameParts) == 2 {
bus = nameParts[0]
name = nameParts[1]
} else {
return fmt.Errorf("Invalid management device name %s", c.String("mgmtdev"))
}
devs, err = vdpa.GetVdpaDevicesByMgmtDev(bus, name)
var mgmtDev = c.String("mgmtdev")
if mgmtDev != "" {
var busName, devName string
busName, devName, err = vdpa.ExtractBusAndMgmtDevice(mgmtDev)
if err != nil {
return err
}
devs, err = vdpa.GetVdpaDevicesByMgmtDev(busName, devName)
} else {
devs, err = vdpa.ListVdpaDevices()
if err != nil {
fmt.Println(err)
}
}
if err != nil {
return err
}
tmpl := template.Must(template.New("device").Parse(deviceTemplate))

Expand Down
55 changes: 20 additions & 35 deletions pkg/kvdpa/device.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package kvdpa

import (
"errors"
"os"
"path/filepath"
"strings"
"syscall"

"github.com/vishvananda/netlink/nl"
Expand Down Expand Up @@ -185,7 +183,8 @@ func GetVdpaDevice(name string) (VdpaDevice, error) {
return nil, err
}

vdpaDevs, err := parseDevLinkVdpaDevList(msgs)
// No filters, expecting to parse attributes for the device with the given name
vdpaDevs, err := parseDevLinkVdpaDevList("", "", msgs)
if err != nil {
return nil, err
}
Expand All @@ -197,50 +196,27 @@ GetVdpaDevicesByMgmtDev returns the VdpaDevice objects whose MgmtDev
has the given bus and device names.
*/
func GetVdpaDevicesByMgmtDev(busName, devName string) ([]VdpaDevice, error) {
result := []VdpaDevice{}
devices, err := ListVdpaDevices()
if err != nil {
return nil, err
}
for _, device := range devices {
if device.MgmtDev() != nil &&
device.MgmtDev().BusName() == busName &&
device.MgmtDev().DevName() == devName {
result = append(result, device)
}
}
if len(result) == 0 {
return nil, syscall.ENODEV
}
return result, nil
return listVdpaDevicesWithBusDevName(busName, devName)
}

/*ListVdpaDevices returns a list of all available vdpa devices */
func ListVdpaDevices() ([]VdpaDevice, error) {
return listVdpaDevicesWithBusDevName("", "")
}

func listVdpaDevicesWithBusDevName(busName, devName string) ([]VdpaDevice, error) {
msgs, err := GetNetlinkOps().RunVdpaNetlinkCmd(VdpaCmdDevGet, syscall.NLM_F_DUMP, nil)
if err != nil {
return nil, err
}

vdpaDevs, err := parseDevLinkVdpaDevList(msgs)
vdpaDevs, err := parseDevLinkVdpaDevList(busName, devName, msgs)
if err != nil {
return nil, err
}
return vdpaDevs, nil
}

func extractBusNameAndMgmtDeviceName(fullMgmtDeviceName string) (busName string, mgmtDeviceName string, err error) {
numSlashes := strings.Count(fullMgmtDeviceName, "/")
if numSlashes > 1 {
return "", "", errors.New("expected mgmtDeviceName to be either in the format <mgmtBusName>/<mgmtDeviceName> or <mgmtDeviceName>")
} else if numSlashes == 0 {
return "", fullMgmtDeviceName, nil
} else {
values := strings.Split(fullMgmtDeviceName, "/")
return values[0], values[1], nil
}
}

/*
GetVdpaDevicesByPciAddress returns the VdpaDevice objects for the given pciAddress
Expand All @@ -249,7 +225,7 @@ GetVdpaDevicesByPciAddress returns the VdpaDevice objects for the given pciAddre
- MgmtDevName
*/
func GetVdpaDevicesByPciAddress(pciAddress string) ([]VdpaDevice, error) {
busName, mgmtDeviceName, err := extractBusNameAndMgmtDeviceName(pciAddress)
busName, mgmtDeviceName, err := ExtractBusAndMgmtDevice(pciAddress)
if err != nil {
return nil, unix.EINVAL
}
Expand All @@ -263,7 +239,7 @@ func AddVdpaDevice(mgmtDeviceName string, vdpaDeviceName string) error {
return unix.EINVAL
}

busName, mgmtDeviceName, err := extractBusNameAndMgmtDeviceName(mgmtDeviceName)
busName, mgmtDeviceName, err := ExtractBusAndMgmtDevice(mgmtDeviceName)
if err != nil {
return unix.EINVAL
}
Expand Down Expand Up @@ -317,7 +293,7 @@ func DeleteVdpaDevice(vdpaDeviceName string) error {
return nil
}

func parseDevLinkVdpaDevList(msgs [][]byte) ([]VdpaDevice, error) {
func parseDevLinkVdpaDevList(busName string, mgmtDeviceName string, msgs [][]byte) ([]VdpaDevice, error) {
devices := make([]VdpaDevice, 0, len(msgs))

for _, m := range msgs {
Expand All @@ -329,6 +305,15 @@ func parseDevLinkVdpaDevList(msgs [][]byte) ([]VdpaDevice, error) {
if err = dev.parseAttributes(attrs); err != nil {
return nil, err
}

if busName != "" && busName != dev.mgmtDev.busName {
continue
}

if mgmtDeviceName != "" && mgmtDeviceName != dev.mgmtDev.devName {
continue
}

if err = dev.getBusInfo(); err != nil {
return nil, err
}
Expand Down
94 changes: 93 additions & 1 deletion pkg/kvdpa/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,98 @@ func TestVdpaDevList(t *testing.T) {
}
}

func TestVdpaDevListWithFilter(t *testing.T) {
tests := []struct {
name string
err bool
response []VdpaDevice
}{
{
name: "Multiple SR-IOV and SF devices",
err: false,
response: []VdpaDevice{
&vdpaDev{
name: "vdpa0",
mgmtDev: &mgmtDev{
devName: "0000:01:01",
},
},
&vdpaDev{
name: "vdpa1",
mgmtDev: &mgmtDev{
busName: "pci",
devName: "0000:01:02",
},
},
&vdpaDev{
name: "vdpa2",
mgmtDev: &mgmtDev{
busName: "pci",
devName: "0000:01:02",
},
},
&vdpaDev{
name: "vdpa3",
mgmtDev: &mgmtDev{
busName: "pci",
devName: "0000:01:03",
},
},
},
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", "TestVdpaDevListWithFilter", tt.name), func(t *testing.T) {
netLinkMock := &mocks.NetlinkOps{}
SetNetlinkOps(netLinkMock)
netLinkMock.On("RunVdpaNetlinkCmd",
VdpaCmdDevGet,
mock.MatchedBy(func(flags int) bool {
return (flags|syscall.NLM_F_DUMP != 0)
}),
mock.AnythingOfType("[]*nl.RtAttr")).
Return(vdpaDevToNlMessage(t, tt.response...), nil)
// no filters, all devices are returned
devs, err := ListVdpaDevices()
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, tt.response, devs)
}
// mgmtdev: 0000:01:01
devs, err = GetVdpaDevicesByPciAddress("0000:01:01")
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, len(devs), 1)
assert.Equal(t, tt.response[0], devs[0])
}
// mgmtdev: pci/0000:01:02
devs, err = GetVdpaDevicesByPciAddress("pci/0000:01:02")
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, len(devs), 2)
assert.Equal(t, tt.response[1], devs[0])
assert.Equal(t, tt.response[2], devs[1])
}
// mgmtdev: pci/0000:01:03
devs, err = GetVdpaDevicesByPciAddress("pci/0000:01:03")
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, len(devs), 1)
assert.Equal(t, tt.response[3], devs[0])
}
})
}
}

func TestVdpaDevGet(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -304,7 +396,7 @@ func TestVdpaDevGetByMgmt(t *testing.T) {
},
{
name: "Wrong",
err: syscall.ENODEV,
response: []VdpaDevice{},
mgmtDevName: "noDev",
mgmtBusName: "wrongBus",
},
Expand Down
22 changes: 22 additions & 0 deletions pkg/kvdpa/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package kvdpa

import (
"errors"
"strings"
)

// ExtractBusAndMgmtDevice extracts the busName and deviceName from a full device address (e.g. pci)
// example 1: pci/65:0000.1 -> "pci", "65:0000.1", nil
// example 2: vdpa_sim -> "", "vdpa_sim", nil
// example 3: pci/65:0000.1/1 -> "", "", err
func ExtractBusAndMgmtDevice(fullMgmtDeviceName string) (busName string, mgmtDeviceName string, err error) {
numSlashes := strings.Count(fullMgmtDeviceName, "/")
if numSlashes > 1 {
return "", "", errors.New("expected mgmtDeviceName to be either in the format <mgmtBusName>/<mgmtDeviceName> or <mgmtDeviceName>")
} else if numSlashes == 0 {
return "", fullMgmtDeviceName, nil
} else {
values := strings.Split(fullMgmtDeviceName, "/")
return values[0], values[1], nil
}
}
53 changes: 53 additions & 0 deletions pkg/kvdpa/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package kvdpa

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestExtractBusAndMgmtDevice(t *testing.T) {
tests := []struct {
testName string
deviceAddress string
busName string
devName string
err bool
}{
{
testName: "regular PCI address",
deviceAddress: "pci/0000:65:00.1",
busName: "pci",
devName: "0000:65:00.1",
err: false,
},
{
testName: "no bus",
deviceAddress: "vdpa_sim",
busName: "",
devName: "vdpa_sim",
err: false,
},
{
testName: "wrong address",
deviceAddress: "pci/0000:65:00.1/0",
busName: "",
devName: "",
err: true,
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", "TestExtractBusAndMgmtDevice", tt.testName), func(t *testing.T) {
busName, devName, err := ExtractBusAndMgmtDevice(tt.deviceAddress)
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, tt.busName, busName)
assert.Equal(t, tt.devName, devName)
}
})
}
}

0 comments on commit 07c1031

Please sign in to comment.