Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: rdma exlusive handling #603

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions pkg/devices/rdma.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,55 @@
package devices

import (
"github.com/golang/glog"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"

"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils"
)

type rdmaSpec struct {
isSupportRdma bool
deviceSpec []*pluginapi.DeviceSpec
deviceID string
deviceType types.DeviceType
}

func newRdmaSpec(rdmaResources []string) types.RdmaSpec {
// NewRdmaSpec returns the RdmaSpec
func NewRdmaSpec(dt types.DeviceType, id string) types.RdmaSpec {
if dt == types.AcceleratorType {
return nil
}
return &rdmaSpec{deviceID: id, deviceType: dt}
}

func (r *rdmaSpec) IsRdma() bool {
if len(r.getRdmaResources()) > 0 {
return true
}
// Checking for netlink param for exclusive RDMA use case
rdma, err := utils.HasRdmaParam(r.deviceID)
if err != nil {
glog.Infof("HasRdmaParam(): unable to get Netlink RDMA param for device %s : %q", r.deviceID, err)
return false
}
return rdma
}

func (r *rdmaSpec) getRdmaResources() []string {
//nolint: exhaustive
switch r.deviceType {
case types.NetDeviceType:
return utils.GetRdmaProvider().GetRdmaDevicesForPcidev(r.deviceID)
case types.AuxNetDeviceType:
return utils.GetRdmaProvider().GetRdmaDevicesForAuxdev(r.deviceID)
default:
return make([]string, 0)
}
}

func (r *rdmaSpec) GetRdmaDeviceSpec() []*pluginapi.DeviceSpec {
rdmaResources := r.getRdmaResources()
deviceSpec := make([]*pluginapi.DeviceSpec, 0)
isSupportRdma := false
if len(rdmaResources) > 0 {
isSupportRdma = true
for _, res := range rdmaResources {
resRdmaDevices := utils.GetRdmaProvider().GetRdmaCharDevices(res)
for _, rdmaDevice := range resRdmaDevices {
Expand All @@ -45,26 +78,5 @@ func newRdmaSpec(rdmaResources []string) types.RdmaSpec {
}
}
}

return &rdmaSpec{isSupportRdma: isSupportRdma, deviceSpec: deviceSpec}
}

// NewRdmaSpec returns the RdmaSpec for PCI address
func NewRdmaSpec(pciAddr string) types.RdmaSpec {
rdmaResources := utils.GetRdmaProvider().GetRdmaDevicesForPcidev(pciAddr)
return newRdmaSpec(rdmaResources)
}

// NewAuxRdmaSpec returns the RdmaSpec for auxiliary device ID
func NewAuxRdmaSpec(deviceID string) types.RdmaSpec {
rdmaResources := utils.GetRdmaProvider().GetRdmaDevicesForAuxdev(deviceID)
return newRdmaSpec(rdmaResources)
}

func (r *rdmaSpec) IsRdma() bool {
return r.isSupportRdma
}

func (r *rdmaSpec) GetRdmaDeviceSpec() []*pluginapi.DeviceSpec {
return r.deviceSpec
return deviceSpec
}
5 changes: 3 additions & 2 deletions pkg/devices/rdma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"

"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/devices"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils/mocks"
)
Expand All @@ -35,7 +36,7 @@ var _ = Describe("RdmaSpec", func() {
fakeRdmaProvider := mocks.RdmaProvider{}
fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0").Return([]string{})
utils.SetRdmaProviderInst(&fakeRdmaProvider)
spec := devices.NewRdmaSpec("0000:00:00.0")
spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0")

Expect(spec.IsRdma()).To(BeFalse())
Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0))
Expand All @@ -50,7 +51,7 @@ var _ = Describe("RdmaSpec", func() {
"/dev/infiniband/uverbs0", "/dev/infiniband/rdma_cm",
}).On("GetRdmaCharDevices", "fake_1").Return([]string{"/dev/infiniband/rdma_cm"})
utils.SetRdmaProviderInst(&fakeRdmaProvider)
spec := devices.NewRdmaSpec("0000:00:00.0")
spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0")

Expect(spec.IsRdma()).To(BeTrue())
Expect(spec.GetRdmaDeviceSpec()).To(Equal([]*pluginapi.DeviceSpec{
Expand Down
10 changes: 1 addition & 9 deletions pkg/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,7 @@ func (rf *resourceFactory) GetResourcePool(rc *types.ResourceConfig, filteredDev
}

func (rf *resourceFactory) GetRdmaSpec(dt types.DeviceType, deviceID string) types.RdmaSpec {
//nolint: exhaustive
switch dt {
case types.NetDeviceType:
return devices.NewRdmaSpec(deviceID)
case types.AuxNetDeviceType:
return devices.NewAuxRdmaSpec(deviceID)
default:
return nil
}
return devices.NewRdmaSpec(dt, deviceID)
}

func (rf *resourceFactory) GetVdpaDevice(pciAddr string) types.VdpaDevice {
Expand Down
5 changes: 5 additions & 0 deletions pkg/factory/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import (
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types/mocks"
"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils"
utilmocks "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils/mocks"

. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/mock"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

Expand Down Expand Up @@ -606,6 +608,9 @@ var _ = Describe("Factory", func() {
)
Describe("getting rdma spec", func() {
Context("check c rdma spec", func() {
mockProvider := &utilmocks.NetlinkProvider{}
mockProvider.On("HasRdmaParam", mock.AnythingOfType("string")).Return(false, nil)
utils.SetNetlinkProviderInst(mockProvider)
f := factory.NewResourceFactory("fake", "fake", true, false)
rs1 := f.GetRdmaSpec(types.NetDeviceType, "0000:00:00.1")
rs2 := f.GetRdmaSpec(types.AcceleratorType, "0000:00:00.2")
Expand Down
28 changes: 28 additions & 0 deletions pkg/utils/mocks/NetlinkProvider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions pkg/utils/netlink_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type NetlinkProvider interface {
GetIPv4RouteList(ifName string) ([]nl.Route, error)
// DevlinkGetDeviceInfoByNameAsMap returns devlink info for selected device as a map
GetDevlinkGetDeviceInfoByNameAsMap(bus, device string) (map[string]string, error)
// HasRdmaParam returns true if PCI device has "enable_rdma" param
HasRdmaParam(pciAddr string) (bool, error)
}

type defaultNetlinkProvider struct {
Expand All @@ -48,6 +50,19 @@ func GetNetlinkProvider() NetlinkProvider {
return netlinkProvider
}

// HasRdmaParam returns true if PCI device has "enable_rdma" param
// equivalent to "devlink dev param show pci/0000:d8:01.1 name enable_rdma"
func (defaultNetlinkProvider) HasRdmaParam(pciAddr string) (bool, error) {
param, err := nl.DevlinkGetDeviceParamByName("pci", pciAddr, "enable_rdma")
if err != nil {
return false, fmt.Errorf("error getting enable_rdma attribute for pci device %s %v", pciAddr, err)
}
if len(param.Values) == 0 || param.Values[0].Data == nil {
return false, nil
}
return true, nil
}

// GetLinkAttrs returns a net device's link attributes.
func (defaultNetlinkProvider) GetLinkAttrs(ifName string) (*nl.LinkAttrs, error) {
link, err := nl.LinkByName(ifName)
Expand Down
9 changes: 9 additions & 0 deletions pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,15 @@ func GetPfEswitchMode(pciAddr string) (string, error) {
return devLinkDeviceAttrs.Mode, nil
}

// HasRdmaParam returns true if PCI device has "enable_rdma" param
func HasRdmaParam(pciAddr string) (bool, error) {
rdma, err := GetNetlinkProvider().HasRdmaParam(pciAddr)
if err != nil {
return false, err
}
return rdma, nil
}

// HasDefaultRoute returns true if PCI network device is default route interface
func HasDefaultRoute(pciAddr string) (bool, error) {
// Get net interface name
Expand Down
Loading