diff --git a/fou.go b/fou.go index 71e73c37..e3d04326 100644 --- a/fou.go +++ b/fou.go @@ -1,21 +1,16 @@ package netlink import ( - "errors" -) - -var ( - // ErrAttrHeaderTruncated is returned when a netlink attribute's header is - // truncated. - ErrAttrHeaderTruncated = errors.New("attribute header truncated") - // ErrAttrBodyTruncated is returned when a netlink attribute's body is - // truncated. - ErrAttrBodyTruncated = errors.New("attribute body truncated") + "net" ) type Fou struct { Family int - Port int Protocol int EncapType int + Port int + PeerPort int + LocalAddr net.IP + PeerAddr net.IP + IfIndex int } diff --git a/fou_linux.go b/fou_linux.go index ed55b2b7..14970c29 100644 --- a/fou_linux.go +++ b/fou_linux.go @@ -3,8 +3,11 @@ package netlink import ( + "bytes" "encoding/binary" "errors" + "log" + "net" "github.com/vishvananda/netlink/nl" "golang.org/x/sys/unix" @@ -24,12 +27,20 @@ const ( const ( FOU_ATTR_UNSPEC = iota - FOU_ATTR_PORT - FOU_ATTR_AF - FOU_ATTR_IPPROTO - FOU_ATTR_TYPE - FOU_ATTR_REMCSUM_NOPARTIAL - FOU_ATTR_MAX = FOU_ATTR_REMCSUM_NOPARTIAL + + FOU_ATTR_PORT /* u16 */ + FOU_ATTR_AF /* u8 */ + FOU_ATTR_IPPROTO /* u8 */ + FOU_ATTR_TYPE /* u8 */ + FOU_ATTR_REMCSUM_NOPARTIAL /* flag */ + FOU_ATTR_LOCAL_V4 /* u32 */ + FOU_ATTR_LOCAL_V6 /* in6_addr */ + FOU_ATTR_PEER_V4 /* u32 */ + FOU_ATTR_PEER_V6 /* in6_addr */ + FOU_ATTR_PEER_PORT /* u16 */ + FOU_ATTR_IFINDEX /* s32 */ + + FOU_ATTR_MAX ) const ( @@ -60,34 +71,55 @@ func FouAdd(f Fou) error { } func (h *Handle) FouAdd(f Fou) error { - fam_id, err := FouFamilyId() - if err != nil { - return err - } - // setting ip protocol conflicts with encapsulation type GUE if f.EncapType == FOU_ENCAP_GUE && f.Protocol != 0 { return errors.New("GUE encapsulation doesn't specify an IP protocol") } - req := h.newNetlinkRequest(fam_id, unix.NLM_F_ACK) + req, err := h.newFouRequest(unix.NLM_F_ACK, FOU_CMD_ADD) + if err != nil { + return err + } - // int to byte for port + req.AddRtAttr(FOU_ATTR_TYPE, []byte{uint8(f.EncapType)}). + AddRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}). + AddRtAttr(FOU_ATTR_IPPROTO, []byte{uint8(f.Protocol)}) + + // local port bp := make([]byte, 2) binary.BigEndian.PutUint16(bp[0:2], uint16(f.Port)) + req.AddRtAttr(FOU_ATTR_PORT, bp) + + // peer port + if f.PeerPort > 0 { + bp = make([]byte, 2) + binary.BigEndian.PutUint16(bp[0:2], uint16(f.PeerPort)) + req.AddRtAttr(FOU_ATTR_PEER_PORT, bp) + } - attrs := []*nl.RtAttr{ - nl.NewRtAttr(FOU_ATTR_PORT, bp), - nl.NewRtAttr(FOU_ATTR_TYPE, []byte{uint8(f.EncapType)}), - nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}), - nl.NewRtAttr(FOU_ATTR_IPPROTO, []byte{uint8(f.Protocol)}), + // local IP address + if f.LocalAddr != nil && !f.LocalAddr.IsUnspecified() { + if f.Family == nl.FAMILY_V4 { + req.AddRtAttr(FOU_ATTR_LOCAL_V4, []byte(f.LocalAddr.To4())) + } else { + req.AddRtAttr(FOU_ATTR_LOCAL_V6, []byte(f.LocalAddr.To16())) + } } - raw := []byte{FOU_CMD_ADD, 1, 0, 0} - for _, a := range attrs { - raw = append(raw, a.Serialize()...) + + // peer IP address + if f.PeerAddr != nil && !f.PeerAddr.IsUnspecified() { + if f.Family == nl.FAMILY_V4 { + req.AddRtAttr(FOU_ATTR_PEER_V4, []byte(f.PeerAddr.To4())) + } else { + req.AddRtAttr(FOU_ATTR_PEER_V6, []byte(f.PeerAddr.To16())) + } } - req.AddRawData(raw) + if f.IfIndex > 0 { + var buf bytes.Buffer + binary.Write(&buf, nl.NativeEndian(), int32(f.IfIndex)) + req.AddRtAttr(FOU_ATTR_IFINDEX, buf.Bytes()) + } _, err = req.Execute(unix.NETLINK_GENERIC, 0) return err @@ -98,34 +130,42 @@ func FouDel(f Fou) error { } func (h *Handle) FouDel(f Fou) error { - fam_id, err := FouFamilyId() + req, err := h.newFouRequest(unix.NLM_F_ACK, FOU_CMD_DEL) if err != nil { return err } - req := h.newNetlinkRequest(fam_id, unix.NLM_F_ACK) - // int to byte for port bp := make([]byte, 2) binary.BigEndian.PutUint16(bp[0:2], uint16(f.Port)) - attrs := []*nl.RtAttr{ - nl.NewRtAttr(FOU_ATTR_PORT, bp), - nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}), - } - raw := []byte{FOU_CMD_DEL, 1, 0, 0} - for _, a := range attrs { - raw = append(raw, a.Serialize()...) + req.AddRtAttr(FOU_ATTR_PORT, bp). + AddRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}) + + if !f.LocalAddr.IsUnspecified() { + if f.Family == nl.FAMILY_V4 { + req.AddRtAttr(FOU_ATTR_LOCAL_V4, f.LocalAddr.To4()) + } else { + req.AddRtAttr(FOU_ATTR_LOCAL_V6, f.LocalAddr.To16()) + } } - req.AddRawData(raw) + if !f.PeerAddr.IsUnspecified() { + if f.Family == nl.FAMILY_V4 { + req.AddRtAttr(FOU_ATTR_PEER_V4, f.PeerAddr.To4()) + } else { + req.AddRtAttr(FOU_ATTR_PEER_V6, f.PeerAddr.To16()) + } + } - _, err = req.Execute(unix.NETLINK_GENERIC, 0) - if err != nil { - return err + if f.IfIndex > 0 { + buf := make([]byte, 4) + native.PutUint32(buf, uint32(f.IfIndex)) + req.AddRtAttr(FOU_ATTR_IFINDEX, buf) } - return nil + _, err = req.Execute(unix.NETLINK_GENERIC, 0) + return err } func FouList(fam int) ([]Fou, error) { @@ -133,79 +173,57 @@ func FouList(fam int) ([]Fou, error) { } func (h *Handle) FouList(fam int) ([]Fou, error) { - fam_id, err := FouFamilyId() + req, err := h.newFouRequest(unix.NLM_F_DUMP, FOU_CMD_GET) if err != nil { return nil, err } - req := h.newNetlinkRequest(fam_id, unix.NLM_F_DUMP) - - attrs := []*nl.RtAttr{ - nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(fam)}), - } - raw := []byte{FOU_CMD_GET, 1, 0, 0} - for _, a := range attrs { - raw = append(raw, a.Serialize()...) - } - - req.AddRawData(raw) + req.AddRtAttr(FOU_ATTR_AF, []byte{uint8(fam)}) msgs, err := req.Execute(unix.NETLINK_GENERIC, 0) if err != nil { return nil, err } - fous := make([]Fou, 0, len(msgs)) - for _, m := range msgs { - f, err := deserializeFouMsg(m) - if err != nil { - return fous, err - } - - fous = append(fous, f) + fous := make([]Fou, len(msgs)) + for i := range msgs { + fous[i] = deserializeFouMsg(msgs[i]) } return fous, nil } -func deserializeFouMsg(msg []byte) (Fou, error) { - // we'll skip to byte 4 to first attribute - msg = msg[3:] - var shift int - fou := Fou{} - - for { - // attribute header is at least 16 bits - if len(msg) < 4 { - return fou, ErrAttrHeaderTruncated - } +func (h *Handle) newFouRequest(flags int, cmd uint8) (*nl.NetlinkRequest, error) { + familyID, err := FouFamilyId() + if err != nil { + return nil, err + } - lgt := int(binary.BigEndian.Uint16(msg[0:2])) - if len(msg) < lgt+4 { - return fou, ErrAttrBodyTruncated - } - attr := binary.BigEndian.Uint16(msg[2:4]) + return h.newNetlinkRequest(familyID, flags).AddRawData([]byte{cmd, 1, 0, 0}), nil +} - shift = lgt + 3 - switch attr { +func deserializeFouMsg(msg []byte) (fou Fou) { + for attr := range nl.ParseAttributes(msg[4:]) { + switch attr.Type { case FOU_ATTR_AF: - fou.Family = int(msg[5]) + fou.Family = int(attr.Value[0]) case FOU_ATTR_PORT: - fou.Port = int(binary.BigEndian.Uint16(msg[5:7])) - // port is 2 bytes - shift = lgt + 2 + fou.Port = int(binary.BigEndian.Uint16(attr.Value)) + case FOU_ATTR_PEER_PORT: + fou.PeerPort = int(binary.BigEndian.Uint16(attr.Value)) case FOU_ATTR_IPPROTO: - fou.Protocol = int(msg[5]) + fou.Protocol = int(attr.Value[0]) case FOU_ATTR_TYPE: - fou.EncapType = int(msg[5]) - } - - msg = msg[shift:] - - if len(msg) < 4 { - break + fou.EncapType = int(attr.Value[0]) + case FOU_ATTR_LOCAL_V4, FOU_ATTR_LOCAL_V6: + fou.LocalAddr = net.IP(attr.Value) + case FOU_ATTR_PEER_V4, FOU_ATTR_PEER_V6: + fou.PeerAddr = net.IP(attr.Value) + case FOU_ATTR_IFINDEX: + fou.IfIndex = int(attr.Int32()) + default: + log.Printf("unknown attribute: %02x", attr) } } - - return fou, nil + return } diff --git a/fou_test.go b/fou_test.go index b252bbad..96db5327 100644 --- a/fou_test.go +++ b/fou_test.go @@ -6,47 +6,45 @@ import ( "testing" ) -func TestFouDeserializeMsg(t *testing.T) { +func TestFouDeserializeMsgEncapDirect(t *testing.T) { var msg []byte // deserialize a valid message msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0, 0, 6, 0, 1, 0, 21, 179, 0, 0, 5, 0, 3, 0, 4, 0, 0, 0, 5, 0, 4, 0, 1, 0, 0, 0} - if fou, err := deserializeFouMsg(msg); err != nil { - t.Error(err.Error()) - } else { + fou := deserializeFouMsg(msg) - // check if message was deserialized correctly - if fou.Family != FAMILY_V4 { - t.Errorf("expected family %d, got %d", FAMILY_V4, fou.Family) - } + // check if message was deserialized correctly + if fou.Family != FAMILY_V4 { + t.Errorf("expected family %d, got %d", FAMILY_V4, fou.Family) + } - if fou.Port != 5555 { - t.Errorf("expected port 5555, got %d", fou.Port) - } + if fou.Port != 5555 { + t.Errorf("expected port 5555, got %d", fou.Port) + } - if fou.Protocol != 4 { // ipip - t.Errorf("expected protocol 4, got %d", fou.Protocol) - } + if fou.Protocol != 4 { // ipip + t.Errorf("expected protocol 4, got %d", fou.Protocol) + } - if fou.EncapType != FOU_ENCAP_DIRECT { - t.Errorf("expected encap type %d, got %d", FOU_ENCAP_DIRECT, fou.EncapType) - } + if fou.EncapType != FOU_ENCAP_DIRECT { + t.Errorf("expected encap type %d, got %d", FOU_ENCAP_DIRECT, fou.EncapType) } // deserialize truncated attribute header msg = []byte{3, 1, 0, 0, 5, 0} - if _, err := deserializeFouMsg(msg); err == nil { - t.Error("expected attribute header truncated error") - } else if err != ErrAttrHeaderTruncated { - t.Errorf("unexpected error: %s", err.Error()) + fou = deserializeFouMsg(msg) + if fou.Family != 0 { + t.Errorf("expected family 0, got %d", fou.Family) } // deserialize truncated attribute header msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0} - if _, err := deserializeFouMsg(msg); err == nil { - t.Error("expected attribute body truncated error") - } else if err != ErrAttrBodyTruncated { - t.Errorf("unexpected error: %s", err.Error()) + fou = deserializeFouMsg(msg) + if fou.Family != 2 { + t.Errorf("expected family 2, got %d", fou.Family) + } + if fou.Protocol != 0 { + t.Errorf("expected protocol 0, got %d", fou.Protocol) } } diff --git a/nl/parse_attr.go b/nl/parse_attr.go index 7f49125c..06637ae2 100644 --- a/nl/parse_attr.go +++ b/nl/parse_attr.go @@ -1,6 +1,7 @@ package nl import ( + "bytes" "encoding/binary" "fmt" "log" @@ -60,6 +61,12 @@ func printAttributes(data []byte, level int) { } } +// Int32 returns the int32 native endian value +func (attr *Attribute) Int32() (ret int32) { + binary.Read(bytes.NewBuffer(attr.Value), NativeEndian(), &ret) + return +} + // Uint32 returns the uint32 value respecting the NET_BYTEORDER flag func (attr *Attribute) Uint32() uint32 { if attr.Type&NLA_F_NET_BYTEORDER != 0 {