Skip to content

Commit

Permalink
Add additional attributes for fou
Browse files Browse the repository at this point in the history
  • Loading branch information
corny committed Oct 22, 2020
1 parent 0f0399b commit 2729627
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 125 deletions.
17 changes: 6 additions & 11 deletions fou.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
package netlink

import (
"errors"
)

var (
// ErrAttrHeaderTruncated is returned when a netlink attribute's header is
// truncated.
ErrAttrHeaderTruncated = errors.New("attribute header truncated")
// ErrAttrBodyTruncated is returned when a netlink attribute's body is
// truncated.
ErrAttrBodyTruncated = errors.New("attribute body truncated")
"net"
)

type Fou struct {
Family int
Port int
Protocol int
EncapType int
Port int
PeerPort int
LocalAddr net.IP
PeerAddr net.IP
IfIndex int
}
196 changes: 107 additions & 89 deletions fou_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
package netlink

import (
"bytes"
"encoding/binary"
"errors"
"log"
"net"

"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
Expand All @@ -24,12 +27,20 @@ const (

const (
FOU_ATTR_UNSPEC = iota
FOU_ATTR_PORT
FOU_ATTR_AF
FOU_ATTR_IPPROTO
FOU_ATTR_TYPE
FOU_ATTR_REMCSUM_NOPARTIAL
FOU_ATTR_MAX = FOU_ATTR_REMCSUM_NOPARTIAL

FOU_ATTR_PORT /* u16 */
FOU_ATTR_AF /* u8 */
FOU_ATTR_IPPROTO /* u8 */
FOU_ATTR_TYPE /* u8 */
FOU_ATTR_REMCSUM_NOPARTIAL /* flag */
FOU_ATTR_LOCAL_V4 /* u32 */
FOU_ATTR_LOCAL_V6 /* in6_addr */
FOU_ATTR_PEER_V4 /* u32 */
FOU_ATTR_PEER_V6 /* in6_addr */
FOU_ATTR_PEER_PORT /* u16 */
FOU_ATTR_IFINDEX /* s32 */

FOU_ATTR_MAX
)

const (
Expand Down Expand Up @@ -60,34 +71,55 @@ func FouAdd(f Fou) error {
}

func (h *Handle) FouAdd(f Fou) error {
fam_id, err := FouFamilyId()
if err != nil {
return err
}

// setting ip protocol conflicts with encapsulation type GUE
if f.EncapType == FOU_ENCAP_GUE && f.Protocol != 0 {
return errors.New("GUE encapsulation doesn't specify an IP protocol")
}

req := h.newNetlinkRequest(fam_id, unix.NLM_F_ACK)
req, err := h.newFouRequest(unix.NLM_F_ACK, FOU_CMD_ADD)
if err != nil {
return err
}

// int to byte for port
req.AddRtAttr(FOU_ATTR_TYPE, []byte{uint8(f.EncapType)}).
AddRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}).
AddRtAttr(FOU_ATTR_IPPROTO, []byte{uint8(f.Protocol)})

// local port
bp := make([]byte, 2)
binary.BigEndian.PutUint16(bp[0:2], uint16(f.Port))
req.AddRtAttr(FOU_ATTR_PORT, bp)

// peer port
if f.PeerPort > 0 {
bp = make([]byte, 2)
binary.BigEndian.PutUint16(bp[0:2], uint16(f.PeerPort))
req.AddRtAttr(FOU_ATTR_PEER_PORT, bp)
}

attrs := []*nl.RtAttr{
nl.NewRtAttr(FOU_ATTR_PORT, bp),
nl.NewRtAttr(FOU_ATTR_TYPE, []byte{uint8(f.EncapType)}),
nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}),
nl.NewRtAttr(FOU_ATTR_IPPROTO, []byte{uint8(f.Protocol)}),
// local IP address
if f.LocalAddr != nil && !f.LocalAddr.IsUnspecified() {
if f.Family == nl.FAMILY_V4 {
req.AddRtAttr(FOU_ATTR_LOCAL_V4, []byte(f.LocalAddr.To4()))
} else {
req.AddRtAttr(FOU_ATTR_LOCAL_V6, []byte(f.LocalAddr.To16()))
}
}
raw := []byte{FOU_CMD_ADD, 1, 0, 0}
for _, a := range attrs {
raw = append(raw, a.Serialize()...)

// peer IP address
if f.PeerAddr != nil && !f.PeerAddr.IsUnspecified() {
if f.Family == nl.FAMILY_V4 {
req.AddRtAttr(FOU_ATTR_PEER_V4, []byte(f.PeerAddr.To4()))
} else {
req.AddRtAttr(FOU_ATTR_PEER_V6, []byte(f.PeerAddr.To16()))
}
}

req.AddRawData(raw)
if f.IfIndex > 0 {
var buf bytes.Buffer
binary.Write(&buf, nl.NativeEndian(), int32(f.IfIndex))
req.AddRtAttr(FOU_ATTR_IFINDEX, buf.Bytes())
}

_, err = req.Execute(unix.NETLINK_GENERIC, 0)
return err
Expand All @@ -98,114 +130,100 @@ func FouDel(f Fou) error {
}

func (h *Handle) FouDel(f Fou) error {
fam_id, err := FouFamilyId()
req, err := h.newFouRequest(unix.NLM_F_ACK, FOU_CMD_DEL)
if err != nil {
return err
}

req := h.newNetlinkRequest(fam_id, unix.NLM_F_ACK)

// int to byte for port
bp := make([]byte, 2)
binary.BigEndian.PutUint16(bp[0:2], uint16(f.Port))

attrs := []*nl.RtAttr{
nl.NewRtAttr(FOU_ATTR_PORT, bp),
nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}),
}
raw := []byte{FOU_CMD_DEL, 1, 0, 0}
for _, a := range attrs {
raw = append(raw, a.Serialize()...)
req.AddRtAttr(FOU_ATTR_PORT, bp).
AddRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)})

if !f.LocalAddr.IsUnspecified() {
if f.Family == nl.FAMILY_V4 {
req.AddRtAttr(FOU_ATTR_LOCAL_V4, f.LocalAddr.To4())
} else {
req.AddRtAttr(FOU_ATTR_LOCAL_V6, f.LocalAddr.To16())
}
}

req.AddRawData(raw)
if !f.PeerAddr.IsUnspecified() {
if f.Family == nl.FAMILY_V4 {
req.AddRtAttr(FOU_ATTR_PEER_V4, f.PeerAddr.To4())
} else {
req.AddRtAttr(FOU_ATTR_PEER_V6, f.PeerAddr.To16())
}
}

_, err = req.Execute(unix.NETLINK_GENERIC, 0)
if err != nil {
return err
if f.IfIndex > 0 {
buf := make([]byte, 4)
native.PutUint32(buf, uint32(f.IfIndex))
req.AddRtAttr(FOU_ATTR_IFINDEX, buf)
}

return nil
_, err = req.Execute(unix.NETLINK_GENERIC, 0)
return err
}

func FouList(fam int) ([]Fou, error) {
return pkgHandle.FouList(fam)
}

func (h *Handle) FouList(fam int) ([]Fou, error) {
fam_id, err := FouFamilyId()
req, err := h.newFouRequest(unix.NLM_F_DUMP, FOU_CMD_GET)
if err != nil {
return nil, err
}

req := h.newNetlinkRequest(fam_id, unix.NLM_F_DUMP)

attrs := []*nl.RtAttr{
nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(fam)}),
}
raw := []byte{FOU_CMD_GET, 1, 0, 0}
for _, a := range attrs {
raw = append(raw, a.Serialize()...)
}

req.AddRawData(raw)
req.AddRtAttr(FOU_ATTR_AF, []byte{uint8(fam)})

msgs, err := req.Execute(unix.NETLINK_GENERIC, 0)
if err != nil {
return nil, err
}

fous := make([]Fou, 0, len(msgs))
for _, m := range msgs {
f, err := deserializeFouMsg(m)
if err != nil {
return fous, err
}

fous = append(fous, f)
fous := make([]Fou, len(msgs))
for i := range msgs {
fous[i] = deserializeFouMsg(msgs[i])
}

return fous, nil
}

func deserializeFouMsg(msg []byte) (Fou, error) {
// we'll skip to byte 4 to first attribute
msg = msg[3:]
var shift int
fou := Fou{}

for {
// attribute header is at least 16 bits
if len(msg) < 4 {
return fou, ErrAttrHeaderTruncated
}
func (h *Handle) newFouRequest(flags int, cmd uint8) (*nl.NetlinkRequest, error) {
familyID, err := FouFamilyId()
if err != nil {
return nil, err
}

lgt := int(binary.BigEndian.Uint16(msg[0:2]))
if len(msg) < lgt+4 {
return fou, ErrAttrBodyTruncated
}
attr := binary.BigEndian.Uint16(msg[2:4])
return h.newNetlinkRequest(familyID, flags).AddRawData([]byte{cmd, 1, 0, 0}), nil
}

shift = lgt + 3
switch attr {
func deserializeFouMsg(msg []byte) (fou Fou) {
for attr := range nl.ParseAttributes(msg[4:]) {
switch attr.Type {
case FOU_ATTR_AF:
fou.Family = int(msg[5])
fou.Family = int(attr.Value[0])
case FOU_ATTR_PORT:
fou.Port = int(binary.BigEndian.Uint16(msg[5:7]))
// port is 2 bytes
shift = lgt + 2
fou.Port = int(binary.BigEndian.Uint16(attr.Value))
case FOU_ATTR_PEER_PORT:
fou.PeerPort = int(binary.BigEndian.Uint16(attr.Value))
case FOU_ATTR_IPPROTO:
fou.Protocol = int(msg[5])
fou.Protocol = int(attr.Value[0])
case FOU_ATTR_TYPE:
fou.EncapType = int(msg[5])
}

msg = msg[shift:]

if len(msg) < 4 {
break
fou.EncapType = int(attr.Value[0])
case FOU_ATTR_LOCAL_V4, FOU_ATTR_LOCAL_V6:
fou.LocalAddr = net.IP(attr.Value)
case FOU_ATTR_PEER_V4, FOU_ATTR_PEER_V6:
fou.PeerAddr = net.IP(attr.Value)
case FOU_ATTR_IFINDEX:
fou.IfIndex = int(attr.Int32())
default:
log.Printf("unknown attribute: %02x", attr)
}
}

return fou, nil
return
}
48 changes: 23 additions & 25 deletions fou_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,45 @@ import (
"testing"
)

func TestFouDeserializeMsg(t *testing.T) {
func TestFouDeserializeMsgEncapDirect(t *testing.T) {
var msg []byte

// deserialize a valid message
msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0, 0, 6, 0, 1, 0, 21, 179, 0, 0, 5, 0, 3, 0, 4, 0, 0, 0, 5, 0, 4, 0, 1, 0, 0, 0}
if fou, err := deserializeFouMsg(msg); err != nil {
t.Error(err.Error())
} else {
fou := deserializeFouMsg(msg)

// check if message was deserialized correctly
if fou.Family != FAMILY_V4 {
t.Errorf("expected family %d, got %d", FAMILY_V4, fou.Family)
}
// check if message was deserialized correctly
if fou.Family != FAMILY_V4 {
t.Errorf("expected family %d, got %d", FAMILY_V4, fou.Family)
}

if fou.Port != 5555 {
t.Errorf("expected port 5555, got %d", fou.Port)
}
if fou.Port != 5555 {
t.Errorf("expected port 5555, got %d", fou.Port)
}

if fou.Protocol != 4 { // ipip
t.Errorf("expected protocol 4, got %d", fou.Protocol)
}
if fou.Protocol != 4 { // ipip
t.Errorf("expected protocol 4, got %d", fou.Protocol)
}

if fou.EncapType != FOU_ENCAP_DIRECT {
t.Errorf("expected encap type %d, got %d", FOU_ENCAP_DIRECT, fou.EncapType)
}
if fou.EncapType != FOU_ENCAP_DIRECT {
t.Errorf("expected encap type %d, got %d", FOU_ENCAP_DIRECT, fou.EncapType)
}

// deserialize truncated attribute header
msg = []byte{3, 1, 0, 0, 5, 0}
if _, err := deserializeFouMsg(msg); err == nil {
t.Error("expected attribute header truncated error")
} else if err != ErrAttrHeaderTruncated {
t.Errorf("unexpected error: %s", err.Error())
fou = deserializeFouMsg(msg)
if fou.Family != 0 {
t.Errorf("expected family 0, got %d", fou.Family)
}

// deserialize truncated attribute header
msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0}
if _, err := deserializeFouMsg(msg); err == nil {
t.Error("expected attribute body truncated error")
} else if err != ErrAttrBodyTruncated {
t.Errorf("unexpected error: %s", err.Error())
fou = deserializeFouMsg(msg)
if fou.Family != 2 {
t.Errorf("expected family 2, got %d", fou.Family)
}
if fou.Protocol != 0 {
t.Errorf("expected protocol 0, got %d", fou.Protocol)
}
}

Expand Down
7 changes: 7 additions & 0 deletions nl/parse_attr.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nl

import (
"bytes"
"encoding/binary"
"fmt"
"log"
Expand Down Expand Up @@ -60,6 +61,12 @@ func printAttributes(data []byte, level int) {
}
}

// Int32 returns the int32 native endian value
func (attr *Attribute) Int32() (ret int32) {
binary.Read(bytes.NewBuffer(attr.Value), NativeEndian(), &ret)
return
}

// Uint32 returns the uint32 value respecting the NET_BYTEORDER flag
func (attr *Attribute) Uint32() uint32 {
if attr.Type&NLA_F_NET_BYTEORDER != 0 {
Expand Down

0 comments on commit 2729627

Please sign in to comment.